rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Bijectors #48

Closed: rlouf closed this issue 3 years ago

rlouf commented 4 years ago

I am convinced (change my mind!) that a PPL needs to own the whole chain, from distributions to inference algorithms, bijectors included. We will thus need to implement bijectors in mcx. We would like the implementation to be close to those of tensorflow-probability and Bijectors.jl, in the sense that it treats constrained-to-unconstrained transformations and more complex (possibly parametrized) bijectors identically.
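To make "treats them identically" concrete, here is a minimal sketch in plain Python, not mcx code. It assumes a bijector is just a record of three functions (`forward`, `inverse`, `forward_log_det_jacobian`), so a fixed constrained-to-unconstrained transform and a parametrized affine transform share one interface. `Bijector`, `exp_bijector`, `affine` and `transformed_logpdf` are hypothetical names chosen for illustration.

```python
import math
from typing import Callable, NamedTuple


class Bijector(NamedTuple):
    """Hypothetical record: a bijector is just three plain functions."""
    forward: Callable[[float], float]
    inverse: Callable[[float], float]
    forward_log_det_jacobian: Callable[[float], float]  # log|f'(x)| at x


# Fixed transform between the reals and the positive reals: exp maps R -> (0, inf).
exp_bijector = Bijector(
    forward=math.exp,
    inverse=math.log,
    forward_log_det_jacobian=lambda x: x,  # log|d exp(x)/dx| = x
)


def affine(shift: float, scale: float) -> Bijector:
    """Parametrized bijector x -> shift + scale * x (requires scale != 0)."""
    return Bijector(
        forward=lambda x: shift + scale * x,
        inverse=lambda y: (y - shift) / scale,
        forward_log_det_jacobian=lambda x: math.log(abs(scale)),
    )


def transformed_logpdf(base_logpdf, bijector):
    """Log density of Y = forward(X), where X has log density `base_logpdf`.

    Change of variables: log p_Y(y) = log p_X(x) - log|f'(x)| with x = f^-1(y).
    """
    def logpdf(y):
        x = bijector.inverse(y)
        return base_logpdf(x) - bijector.forward_log_det_jacobian(x)
    return logpdf


def std_normal_logpdf(x):
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


lognormal_logpdf = transformed_logpdf(std_normal_logpdf, exp_bijector)
shifted_normal_logpdf = transformed_logpdf(std_normal_logpdf, affine(1.0, 2.0))
```

Because both kinds of bijector expose the same three functions, `transformed_logpdf` never needs to know whether the bijector carries parameters.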

I am opening this issue to start thinking about the API: how bijectors would integrate with the rest of the library, and how users would interact with them.

rlouf commented 4 years ago

We can start with something very simple, though it will probably not cover normalizing flows:

  1. Refer to each bijector by its forward transformation, implemented as a plain function;
  2. Write a Jaxpr transformation that inverts functions using a registry of forward/backward (inverse) function pairs;
  3. Keep a registry mapping each support to its transform.
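The three steps can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions: step 2 here works at the level of Python functions rather than Jaxpr primitives, inverting a composition by inverting each registered piece and applying them in reverse order (a Jaxpr-level inverter would presumably do the same over primitives). `INVERSES`, `SUPPORT_TO_TRANSFORM`, `compose`, `inverse`, `add_one` and the support names are all hypothetical.

```python
import math

# Step 1: each bijector is referred to by its forward transformation,
# a plain function from the unconstrained reals to a constrained support.
def identity(x):
    return x

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def add_one(x):
    return x + 1.0

# Step 2 (simplified, function-level rather than Jaxpr-level):
# a registry of forward -> inverse functions.
INVERSES = {
    identity: identity,
    math.exp: math.log,
    sigmoid: lambda y: math.log(y) - math.log1p(-y),  # logit
    add_one: lambda y: y - 1.0,
}

def compose(*fns):
    """Return the composition that applies `fns` from left to right."""
    def composed(x):
        for fn in fns:
            x = fn(x)
        return x
    composed.parts = fns  # remember the pieces so the composition can be inverted
    return composed

def inverse(fn):
    """Invert a registered function, or a composition of registered functions."""
    parts = getattr(fn, "parts", None)
    if parts is None:
        return INVERSES[fn]
    # (g . f)^-1 = f^-1 . g^-1: invert each piece and apply them in reverse order.
    return compose(*(inverse(p) for p in reversed(parts)))

# Step 3: registry mapping a distribution's support to its transform.
SUPPORT_TO_TRANSFORM = {
    "real": identity,
    "positive": math.exp,
    "unit_interval": sigmoid,
    "greater_than_one": compose(math.exp, add_one),
}
```

For example, `inverse(SUPPORT_TO_TRANSFORM["greater_than_one"])` returns a function that first subtracts one and then takes the logarithm.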

Since JAX can compute Jacobians automatically, the generated code will look very similar for every transformation. There is no need to even create classes.
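A sketch of that point, assuming JAX's `jax.jacfwd` and `jnp.linalg.slogdet`: one generic helper computes log|det J| for any forward transformation, so each bijector only has to supply its forward function. The helper names `log_abs_det_jacobian` and `make_unconstrained_logpdf` are hypothetical, not mcx's API.

```python
import jax
import jax.numpy as jnp


def log_abs_det_jacobian(forward, y):
    """Generic log|det J_forward(y)|, using a dense Jacobian from jax.jacfwd."""
    y = jnp.atleast_1d(y)                  # treat scalars as 1-vectors
    jac = jax.jacfwd(forward)(y)           # shape (n, n) for R^n -> R^n
    _, logdet = jnp.linalg.slogdet(jac)    # (sign, log|det|)
    return logdet


def make_unconstrained_logpdf(logpdf, forward):
    """Log density in unconstrained space, where x = forward(y) is constrained.

    Change of variables: log p_Y(y) = log p_X(forward(y)) + log|det J_forward(y)|.
    """
    def unconstrained_logpdf(y):
        return logpdf(forward(y)) + log_abs_det_jacobian(forward, y)
    return unconstrained_logpdf


# Exponential(1) on (0, inf), sampled on the real line through exp.
exponential_logpdf = lambda x: -jnp.sum(x)
unconstrained = make_unconstrained_logpdf(exponential_logpdf, jnp.exp)
```

The same two functions work unchanged for `jnp.exp`, `jax.nn.sigmoid` or any other differentiable bijection. The trade-off is that a dense Jacobian takes O(n^2) memory, so elementwise transforms may still deserve a cheaper special case.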