Closed by mjo22 7 months ago
Computing batches of log_prob values raises questions about shapes, vmap, and the independence structure of inference.
Probabilistic programming frameworks like Pyro have ways to do this bookkeeping (batch dimensions, plates): https://pyro.ai/examples/tensor_shapes.html https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/
This makes sense. This may be a reason not to include things like this in the library! What I wrote is easy enough to do externally, and in general we should stay away from assuming exactly how a user might want to do this. A good alternative would be to write up good notebook examples in the docs.
Closing, as I think this is no longer the way to go. It is easy enough to just vmap in a user's specific workflow.
I've been thinking more about how we might build batching into the library, specifically in order to compute the likelihood. If there is interest, it would be easy to implement a new class in `cryojax.inference` that computes a batch of log-likelihood values.
This is a good starting point for a skeleton, I think. It uses utilities I've recently written in `cryojax.core` (see here for a tutorial).
I'm not totally sure this is something that should go in the library, because this `LogLikelihoodBatcher` class is a function that's only a few lines (I would even recommend people try to write things like this themselves). But it depends on people's interest and the direction of the library!
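For anyone landing here, this is a rough sketch of what "just vmap it yourself" can look like in plain JAX. It does not use any cryojax API: `simulate`, `log_likelihood`, and the toy 1D Gaussian "image" are hypothetical stand-ins, and the noise model (independent Gaussian noise per pixel) is an assumption made only for illustration. The point is that `in_axes` is where the user states the independence structure explicitly: which arguments carry a batch dimension and which are shared.

```python
import jax
import jax.numpy as jnp


def simulate(params, coords):
    # Hypothetical stand-in for an image formation model:
    # a 1D Gaussian blob with an amplitude and a center.
    amplitude, center = params[0], params[1]
    return amplitude * jnp.exp(-0.5 * (coords - center) ** 2)


def log_likelihood(params, observed, coords, noise_std=1.0):
    # Assumes independent Gaussian noise per pixel, so the
    # log-likelihood is a sum over pixels for a single image.
    residual = observed - simulate(params, coords)
    norm = observed.size * jnp.log(noise_std * jnp.sqrt(2.0 * jnp.pi))
    return -0.5 * jnp.sum((residual / noise_std) ** 2) - norm


# Batch over the leading axis of `params` and `observed`;
# `coords` is shared across the batch (in_axes=None).
batched_log_likelihood = jax.vmap(log_likelihood, in_axes=(0, 0, None))

coords = jnp.linspace(-5.0, 5.0, 64)
params = jnp.array([[1.0, 0.0], [2.0, 1.0], [0.5, -1.0]])
# Noiseless toy "data", one image per parameter set.
observed = jax.vmap(simulate, in_axes=(0, None))(params, coords)

# One log-likelihood value per (params, image) pair.
log_probs = batched_log_likelihood(params, observed, coords)
```

Changing `in_axes` changes the batching semantics without touching `log_likelihood`: for example, `in_axes=(0, None, None)` would instead evaluate many parameter sets against a single shared image. That flexibility is a large part of why keeping this outside the library seems reasonable.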