rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Add leading singleton dimensions to batch shape #85

Closed rlouf closed 3 years ago

rlouf commented 3 years ago

Fixes #83. I need to go through all distributions and their tests.

rlouf commented 3 years ago

This PR changes the original design of shapes in distributions, which assigned a batch shape of 1 to single-number inputs. Besides not always being practical, that design leads to broadcasting issues for multivariate distributions. We therefore took a different approach: instead of promoting the shape of the distributions' output, we promote the shape of the parameters. As a result, we also had to make a small modification to `sample_forward`, since the output shape changed.
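
To make the idea concrete, here is a minimal sketch of parameter-shape promotion. It is not the actual mcx code: `promote_params` and `batch_shape` are hypothetical helpers written for illustration, and NumPy is used because it follows the same broadcasting rules as `jax.numpy`. Each parameter gets leading singleton dimensions until all parameters have the same rank, and the batch shape is then whatever the promoted parameters broadcast to.

```python
import numpy as np


def promote_params(*params):
    """Add leading singleton axes so all parameters have the same rank.

    Hypothetical helper, for illustration only.
    """
    arrays = [np.asarray(p) for p in params]
    ndim = max(a.ndim for a in arrays)
    # Prepend (ndim - a.ndim) axes of size 1 to each parameter's shape.
    return tuple(a.reshape((1,) * (ndim - a.ndim) + a.shape) for a in arrays)


def batch_shape(*params):
    """Batch shape implied by broadcasting the promoted parameters."""
    promoted = promote_params(*params)
    return np.broadcast_shapes(*(p.shape for p in promoted))


# A scalar location and a (3, 2) array of scales.
mu, sigma = promote_params(0.0, np.ones((3, 2)))
print(mu.shape, sigma.shape)

# The batch shape is driven by the parameters, not by reshaping the output.
print(batch_shape(0.0, np.ones((3, 2))))
print(batch_shape(np.ones(2), np.ones((3, 1))))
```

Because the promotion happens on the parameters, the sampling and log-density code can rely on ordinary broadcasting and does not need to reshape its result afterwards.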

@sidravi1 I had to change the initialization of `MvNormal` because of this.