yebai opened 2 months ago
I really want to make this work; I'll spend some time on it and try to produce a prototype soon.
If we utilise numpyro distributions, this looks quite doable: https://num.pyro.ai/en/stable/distributions.html
TensorFlow Probability and plain jax.random are also good options.
jax.random only provides samplers for common distributions; it has no log-density functions.
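To illustrate the gap with plain JAX, here is a small sketch assuming only `jax` is installed: `jax.random` draws the samples, but the log-density has to come from elsewhere, either `jax.scipy.stats` (which covers some families) or a hand-written function. The `normal_logpdf` helper is hypothetical, written here just for comparison.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(42)
loc, scale = 2.0, 0.5

# jax.random only samples: draw from N(0, 1), then shift and scale by hand.
samples = loc + scale * jax.random.normal(key, shape=(1000,))

# The log-density is not in jax.random; jax.scipy.stats provides it for the normal.
log_probs = norm.logpdf(samples, loc=loc, scale=scale)

# Hypothetical hand-written Gaussian log-density, as one would need for other families.
def normal_logpdf(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi)

manual = normal_logpdf(samples, loc, scale)
```

Having to pair each sampler with a separately sourced log-density is the extra bookkeeping that a distributions library would avoid.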
DeepMind's distrax reimplements a subset of TFP in native JAX.