jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

sghmc slow to compile for flax CNN #13

Closed: jeremiecoullon closed this issue 3 years ago.

jeremiecoullon commented 3 years ago

sghmc is very slow to compile for the Flax CNN example (on MNIST) in the docs.

This might be because the leapfrog steps are done with a JAX scan (jax.lax.scan) rather than a Python loop.
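To make the two options concrete, here is a minimal sketch, not SGMCMCJax's code, of a leapfrog trajectory written both ways. It assumes a toy standard-Gaussian potential in place of the Flax CNN, and the function names are hypothetical. With lax.scan the whole trajectory is traced and compiled as one XLA program; with a Python loop only a single step is jitted (compiled once) and Python drives the iterations, trading compile work for per-step dispatch overhead. Which one compiles faster for a given model and JAX version is exactly what the issue is about, so treat this as an illustration of the structure rather than a benchmark.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Toy potential U(x) = 0.5 * ||x||^2 (an assumption standing in for a CNN loss).
def potential(x):
    return 0.5 * jnp.sum(x ** 2)


grad_U = jax.grad(potential)


def leapfrog_step(x, v, step_size):
    """One leapfrog step: half kick, full drift, half kick."""
    v = v - 0.5 * step_size * grad_U(x)
    x = x + step_size * v
    v = v - 0.5 * step_size * grad_U(x)
    return x, v


# Option A: the whole trajectory is a single jitted program built with lax.scan.
# n_steps must be static because it sets the scan length.
@partial(jax.jit, static_argnames="n_steps")
def leapfrog_scan(x, v, step_size, n_steps):
    def body(carry, _):
        x, v = carry
        return leapfrog_step(x, v, step_size), None

    (x, v), _ = jax.lax.scan(body, (x, v), None, length=n_steps)
    return x, v


# Option B: only one step is compiled; an ordinary Python loop calls it n_steps times.
jitted_step = jax.jit(leapfrog_step)


def leapfrog_python_loop(x, v, step_size, n_steps):
    for _ in range(n_steps):
        x, v = jitted_step(x, v, step_size)
    return x, v


# Both versions compute the same trajectory (up to floating-point rounding).
x0 = jnp.ones(3)
v0 = jnp.zeros(3)
xa, va = leapfrog_scan(x0, v0, 0.1, 10)
xb, vb = leapfrog_python_loop(x0, v0, 0.1, 10)
```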

jeremiecoullon commented 3 years ago

Added a compiled_leapfrog argument to the sghmc kernel.
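One way such a flag could work is sketched below. Only the name compiled_leapfrog comes from this issue; the factory build_sghmc_kernel, its signature, and the kernel's (key, x, v) interface are hypothetical and are not SGMCMCJax's actual API. The update is a simplified SGHMC-style step (momentum with friction plus injected Gaussian noise), and an exact gradient of a toy Gaussian log-posterior stands in for a minibatch gradient estimate. With compiled_leapfrog=True all steps run inside one jitted lax.scan; with False a jitted single step is driven by a Python loop. Pre-splitting the PRNG keys keeps both paths on the same random stream, so they should agree up to rounding.

```python
import jax
import jax.numpy as jnp


def build_sghmc_kernel(grad_log_post, step_size, friction, n_leapfrog,
                       compiled_leapfrog=True):
    """Hypothetical factory returning kernel(key, x, v) -> (x, v)."""

    def step(x, v, key):
        # Simplified SGHMC-style update: friction on v, gradient kick, injected noise.
        noise = jax.random.normal(key, x.shape)
        v = ((1.0 - friction) * v
             + step_size * grad_log_post(x)
             + jnp.sqrt(2.0 * friction * step_size) * noise)
        x = x + v
        return x, v

    if compiled_leapfrog:
        # Whole trajectory compiled as one program via lax.scan over the keys.
        @jax.jit
        def kernel(key, x, v):
            keys = jax.random.split(key, n_leapfrog)

            def body(carry, k):
                return step(*carry, k), None

            (x, v), _ = jax.lax.scan(body, (x, v), keys)
            return x, v
    else:
        # Only the single step is compiled; Python iterates over the keys.
        jitted_step = jax.jit(step)

        def kernel(key, x, v):
            keys = jax.random.split(key, n_leapfrog)
            for k in keys:
                x, v = jitted_step(x, v, k)
            return x, v

    return kernel


# Usage with a toy Gaussian log-posterior (an assumption for illustration).
grad_log_post = jax.grad(lambda x: -0.5 * jnp.sum(x ** 2))
key = jax.random.PRNGKey(0)
x0, v0 = jnp.ones(2), jnp.zeros(2)

kernel_scan = build_sghmc_kernel(grad_log_post, 1e-2, 0.1, 5, compiled_leapfrog=True)
kernel_loop = build_sghmc_kernel(grad_log_post, 1e-2, 0.1, 5, compiled_leapfrog=False)
xa, va = kernel_scan(key, x0, v0)
xb, vb = kernel_loop(key, x0, v0)
```

Returning a closure from a factory keeps the choice out of the traced code: the flag is resolved once in Python, so neither path carries a runtime branch.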