rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Core & custom Jaxprs #62

Status: Open. rlouf opened this issue 3 years ago.

rlouf commented 3 years ago

Here are a few ideas for the core that I won't have time to implement before the first release, but that have the potential to make MCX more general and greatly simplify the bijectors API.

First, a shortcoming of many PPLs is that they cannot condition on deterministic transformations of random variables. This is because the logpdf function would need to account for the change in volume by adding the log absolute determinant of the Jacobian of the inverse transformation (equivalently, subtracting that of the forward transformation). This seems to be a job for Jaxprs: the idea would be to have the core compile the graph into a function that JAX can manipulate, and to create a "logpdf" Jaxpr transformation that is applied to this function.
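As a minimal illustration of the volume correction (not MCX code), the sketch below computes the log-density of Y = f(X) from the log-density of X and an explicitly supplied inverse f^{-1}, using the change-of-variables formula log p_Y(y) = log p_X(f^{-1}(y)) + log|det J_{f^{-1}}(y)|. The helper name `transformed_logpdf` is hypothetical; the Jacobian is obtained with `jax.jacfwd`, and `base_logpdf` is assumed to return a scalar.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def transformed_logpdf(base_logpdf, inverse):
    """Hypothetical helper: log-density of Y = f(X) given log p_X and f^{-1}.

    Change of variables:
        log p_Y(y) = log p_X(f^{-1}(y)) + log |det J_{f^{-1}}(y)|
    `base_logpdf` is assumed to return a scalar.
    """

    def logpdf(y):
        x = inverse(y)
        # Jacobian of the inverse map at y; for a scalar y this is a 0-d array.
        jac = jax.jacfwd(inverse)(y)
        # atleast_2d turns a scalar Jacobian into a 1x1 matrix for slogdet.
        _, logabsdet = jnp.linalg.slogdet(jnp.atleast_2d(jac))
        return base_logpdf(x) + logabsdet

    return logpdf


# Y = exp(X) with X ~ Normal(0, 1), i.e. Y is log-normal; the inverse is log.
lognormal_logpdf = transformed_logpdf(norm.logpdf, jnp.log)
print(lognormal_logpdf(2.0))
```

Without the `logabsdet` term, `lognormal_logpdf` would simply evaluate the Normal density at log(y), which is not a normalized density over y.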

Then, if this works, we would only need to implement the "forward" part of bijectors; the logpdf Jaxpr transformation would automatically take care of conserving volume. Writing a Jaxpr transformation that inverts the forward transformation is possible, if not easy.
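A minimal sketch of such an inverting transformation, assuming the forward function is a chain of unary, shape-preserving primitives: trace it to a Jaxpr with `jax.make_jaxpr`, then walk the equations backwards, applying a registered inverse for each primitive. The registry `INVERSES` and the function `inverse` are hypothetical and only cover `exp`, `log` and `neg`; a real implementation would also need binary primitives with constant operands, nested call primitives (e.g. from `jax.jit`), and many more primitives.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical registry mapping a primitive to the function that inverts it.
# Only unary, elementwise bijections are covered in this sketch.
INVERSES = {
    lax.exp_p: lax.log,
    lax.log_p: lax.exp,
    lax.neg_p: lax.neg,
}


def inverse(fun):
    """Invert `fun` by walking its Jaxpr backwards, one equation at a time."""

    def inverted(y):
        # Trace `fun` at an input with the same shape/dtype as `y`
        # (valid because every supported primitive is shape-preserving).
        jaxpr = jax.make_jaxpr(fun)(y).jaxpr
        if len(jaxpr.invars) != 1 or len(jaxpr.outvars) != 1:
            raise NotImplementedError("only single-input, single-output functions")
        # Environment mapping Jaxpr variables to their (inverted) values.
        env = {jaxpr.outvars[0]: y}
        for eqn in reversed(jaxpr.eqns):
            if eqn.primitive not in INVERSES:
                raise NotImplementedError(f"no inverse registered for {eqn.primitive}")
            out_val = env[eqn.outvars[0]]
            env[eqn.invars[0]] = INVERSES[eqn.primitive](out_val)
        return env[jaxpr.invars[0]]

    return inverted


# Forward-only "bijector": y = -exp(x). Its inverse is derived automatically.
f = lambda x: lax.neg(lax.exp(x))
y = f(jnp.float32(0.5))
x = inverse(f)(y)
print(x)
```

Combined with `jax.jacfwd`, an automatically derived inverse like this is what would let the logpdf transformation compute the volume correction from the forward definition alone.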

Graph --> JAX-ready logpdf --> logpdf Jaxpr
Graph --> Joint distribution forward sampler
Graph --> Predictive distribution sampler
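To make the three derivations above concrete, here is a toy sketch using a hypothetical graph representation (an ordered list of nodes whose Normal parameters depend on previously defined nodes); it is not MCX's internal graph. From the same graph it derives a JAX-ready logpdf, whose Jaxpr can be inspected with `jax.make_jaxpr`, a joint forward sampler (ancestral sampling), and a predictive sampler that holds some nodes fixed and samples the rest. Passing posterior draws of `x` as the fixed values would give posterior predictive draws of `y`.

```python
import jax
from jax.scipy.stats import norm

# Hypothetical graph representation: ordered (name, parameters) pairs, where
# `parameters` maps already-computed values to the (loc, scale) of a Normal.
# Model: x ~ Normal(0, 1); y ~ Normal(x, 1).
GRAPH = [
    ("x", lambda values: (0.0, 1.0)),
    ("y", lambda values: (values["x"], 1.0)),
]


def make_logpdf(graph):
    """Graph --> JAX-ready logpdf: sum of every node's log-density."""

    def logpdf(values):
        total = 0.0
        for name, parameters in graph:
            loc, scale = parameters(values)
            total = total + norm.logpdf(values[name], loc, scale)
        return total

    return logpdf


def make_predictive_sampler(graph, fixed):
    """Graph --> sampler that keeps the nodes in `fixed` and samples the others
    in order, conditional on the values computed so far."""

    def sample(rng_key):
        values = dict(fixed)
        keys = jax.random.split(rng_key, len(graph))
        for key, (name, parameters) in zip(keys, graph):
            if name in values:
                continue
            loc, scale = parameters(values)
            values[name] = loc + scale * jax.random.normal(key)
        return values

    return sample


def make_forward_sampler(graph):
    """Graph --> joint forward sampler (ancestral sampling of every node)."""
    return make_predictive_sampler(graph, fixed={})


logpdf = make_logpdf(GRAPH)
# Graph --> JAX-ready logpdf --> logpdf Jaxpr
print(jax.make_jaxpr(logpdf)({"x": 0.1, "y": 0.2}))

joint_sample = make_forward_sampler(GRAPH)(jax.random.PRNGKey(0))
predictive_sample = make_predictive_sampler(GRAPH, {"x": 3.0})(jax.random.PRNGKey(1))
```

Only the logpdf goes through the Jaxpr stage here; the samplers are plain JAX functions, which matches the view that sampling can be left as is.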

Forward sampling and predictive sampling are simple enough that they can be left as they are.

As a result, we would have a two-layer core: