mjo22 / cryojax

Cryo electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1

Probabilistic programming library support #10

Closed: mjo22 closed this issue 1 year ago

mjo22 commented 1 year ago

cryojax should integrate with the JAX Bayesian inference ecosystem. In particular, it would be great to support probabilistic programming libraries (PPLs) such as PyMC and NumPyro.

For example, tinygp adds NumPyro support by subclassing a NumPyro `Distribution`: https://github.com/dfm/tinygp/blob/main/src/tinygp/numpyro_support.py.

PyMC also documents how to wrap JAX functions: https://www.pymc.io/projects/examples/en/latest/case_studies/wrapping_jax_function.html

mjo22 commented 1 year ago

We should not tie ourselves to any one framework or add a PPL as a dependency. To use a PPL, one can instead, say, sample parameters from a NumPyro `Distribution` and use them to build a model around cryojax.

At some point, there may still be an argument for adding NumPyro support by subclassing `Distribution`.