TuringLang / JuliaBUGS.jl

A domain-specific language (DSL) for probabilistic graphical models
https://turinglang.org/JuliaBUGS.jl/
MIT License

Compile to JAX to enable GPU/TPU acceleration and `vmap`. #209
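As a rough illustration of what "compile to JAX" could mean, the sketch below hand-writes the log density that a compiler might emit for a tiny BUGS-style model and wraps it in `jax.jit` and `jax.vmap`. The model, the helper `bugs_dnorm_logpdf`, and the use of `jax.scipy.stats` for densities are assumptions for illustration only. This is not JuliaBUGS.jl's actual output.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical BUGS model being "compiled":
#   mu ~ dnorm(0, 0.0001)                       # BUGS dnorm takes a precision
#   for (i in 1:N) { y[i] ~ dnorm(mu, tau) }    # tau held fixed here
def bugs_dnorm_logpdf(x, mean, precision):
    # Convert the BUGS precision to the standard deviation jax.scipy.stats expects.
    return norm.logpdf(x, loc=mean, scale=1.0 / jnp.sqrt(precision))

def log_joint(mu, y, tau=1.0):
    # Log prior on mu plus the log likelihood of every observation.
    prior = bugs_dnorm_logpdf(mu, 0.0, 1e-4)
    likelihood = jnp.sum(bugs_dnorm_logpdf(y, mu, tau))
    return prior + likelihood

# jit compiles with XLA (CPU, GPU or TPU); vmap maps over a batch of mu values
# while y is shared (in_axes=None), with no Python loop.
batched_log_joint = jax.jit(jax.vmap(log_joint, in_axes=(0, None)))

y = jnp.array([0.8, 1.2, 1.0])
mus = jnp.linspace(-1.0, 3.0, 5)
print(batched_log_joint(mus, y))

# The same function is differentiable, which gradient-based samplers (e.g. HMC) need.
grad_mu = jax.grad(log_joint)(1.0, y)
print(grad_mu)
```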

Status: Open. yebai opened this issue 2 months ago.

sunxd3 commented 2 months ago

I really want to make this work. I will spend some time on it and try to produce a prototype soon.

yebai commented 2 months ago

If we use NumPyro's distributions, this looks quite doable: https://num.pyro.ai/en/stable/distributions.html
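One detail any such backend would have to handle is that BUGS parameterizes some distributions differently from JAX libraries: `dnorm` takes a precision, while NumPyro's `Normal` takes a standard deviation (`scale`). The sketch below shows a possible translation layer. It uses `jax.scipy.stats` as a dependency-free stand-in for `numpyro.distributions`, and the table `BUGS_LOGDENSITY` and function `logdensity` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy import stats

# Hypothetical lookup table: BUGS distribution name -> JAX log-density,
# converting BUGS parameterizations to those used by jax.scipy.stats.
BUGS_LOGDENSITY = {
    # dnorm(mu, tau): tau is a precision, so scale = 1 / sqrt(tau)
    "dnorm": lambda x, mu, tau: stats.norm.logpdf(x, loc=mu, scale=1.0 / jnp.sqrt(tau)),
    # dgamma(r, mu): shape r and rate mu, so scale = 1 / mu
    "dgamma": lambda x, r, mu: stats.gamma.logpdf(x, r, scale=1.0 / mu),
    # dpois(lambda) and dbern(p) need no conversion
    "dpois": lambda x, lam: stats.poisson.logpmf(x, lam),
    "dbern": lambda x, p: stats.bernoulli.logpmf(x, p),
}

def logdensity(dist_name, x, *params):
    # Dispatch on the BUGS name and evaluate the translated log-density.
    return BUGS_LOGDENSITY[dist_name](x, *params)

print(logdensity("dnorm", 0.0, 0.0, 4.0))   # Normal(0, sd=0.5) evaluated at 0
print(logdensity("dgamma", 1.0, 2.0, 3.0))  # Gamma(shape=2, rate=3) evaluated at 1
```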

sunxd3 commented 2 months ago

TensorFlow Probability and plain jax.random are also good options.

yebai commented 2 months ago

jax.random only provides samplers for common distributions.
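To make the distinction concrete: `jax.random` gives draws but no log-densities, so using it directly means pairing each sampler with a density from elsewhere and doing the parameter conversion by hand. The sketch below does that for BUGS `dgamma(r, mu)`. The class `BugsGamma` is hypothetical and only illustrates the bookkeeping that a library such as NumPyro, TFP or Distrax would otherwise provide.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

class BugsGamma:
    """Hypothetical wrapper for BUGS dgamma(r, mu): shape r, rate mu."""

    def __init__(self, r, mu):
        self.r = r
        self.mu = mu

    def sample(self, key, shape=()):
        # jax.random.gamma draws from Gamma(shape=a, scale=1);
        # dividing by the rate gives the BUGS parameterization.
        return jax.random.gamma(key, self.r, shape=shape) / self.mu

    def log_prob(self, x):
        # jax.random has no density functions, so use jax.scipy.stats instead.
        return stats.gamma.logpdf(x, self.r, scale=1.0 / self.mu)

d = BugsGamma(r=2.0, mu=3.0)
draws = d.sample(jax.random.PRNGKey(0), shape=(100_000,))
print(draws.mean())     # should be close to r / mu = 2/3
print(d.log_prob(1.0))
```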

DeepMind's Distrax reimplements a subset of TFP's distributions and bijectors in native JAX.