rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Batch shape shouldn’t default to 1 #83

Closed rlouf closed 3 years ago

rlouf commented 3 years ago

We currently default batch_shape to 1 to avoid broadcasting errors when sampling from prior distributions. However, this can introduce errors of its own, e.g. when arrays are initialized with the realized values of other random variables.

Instead of promoting the output of sample, we should promote the shape of scalar parameters to (1,) when initializing the distributions.
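
A minimal sketch of what this could look like, assuming a toy Normal class and a hypothetical promote_scalar helper (neither is taken from the mcx codebase; they only illustrate the idea). Scalar parameters are reshaped to (1,) in the constructor, and the batch shape is then derived from the parameters instead of being forced onto the samples:

```python
import jax
import jax.numpy as jnp


def promote_scalar(x):
    """Hypothetical helper: reshape 0-d values to (1,), leave arrays untouched."""
    x = jnp.asarray(x)
    return x.reshape((1,)) if x.ndim == 0 else x


class Normal:
    """Toy distribution used for illustration only, not the mcx implementation."""

    def __init__(self, mu, sigma):
        # Promotion happens once, when the distribution is initialized.
        self.mu = promote_scalar(mu)
        self.sigma = promote_scalar(sigma)
        # The batch shape follows from broadcasting the parameters.
        self.batch_shape = jax.lax.broadcast_shapes(self.mu.shape, self.sigma.shape)

    def sample(self, rng_key, sample_shape=()):
        # No extra promotion of the output: its shape is sample_shape + batch_shape.
        shape = tuple(sample_shape) + self.batch_shape
        return self.mu + self.sigma * jax.random.normal(rng_key, shape)


rng_key = jax.random.PRNGKey(0)
scalar_draw = Normal(0.0, 1.0).sample(rng_key)                 # parameters promoted to (1,)
array_draw = Normal(jnp.zeros(3), 1.0).sample(rng_key, (10,))  # array parameter keeps its shape
```

In this sketch a distribution built from plain numbers has batch shape (1,), while one built from an array keeps that array's shape, so nothing is added to the output of sample after the fact.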