Closed: jeremiecoullon closed this issue 3 years ago.
sghmc is really slow to compile for the flax CNN example (using MNIST) in the docs.
This might be because the leapfrog steps are done using a JAX scan (lax.scan) rather than a Python loop.
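One way to test this hypothesis is to time compilation alone for both styles of loop. The sketch below is hypothetical and uses a toy Gaussian gradient in place of the flax CNN. It also uses a simple momentum/position update rather than the actual sghmc update. `lax.scan` traces the step body once. A Python loop is unrolled at trace time, so the compiled graph holds L copies of the step. Which of the two compiles faster depends on the model and on L, so timing both on the real CNN is the conclusive check.

```python
import time

import jax
import jax.numpy as jnp


def grad_log_post(x):
    # Hypothetical stand-in for the CNN gradient: a standard Gaussian target.
    return -x


def integrate_scan(x, v, dt, L):
    # lax.scan traces `step` once and runs it L times inside a single XLA loop.
    def step(carry, _):
        x, v = carry
        v = v + dt * grad_log_post(x)
        x = x + dt * v
        return (x, v), None

    (x, v), _ = jax.lax.scan(step, (x, v), None, length=L)
    return x, v


def integrate_loop(x, v, dt, L):
    # A Python loop is unrolled while tracing: L copies of the step end up in the graph.
    for _ in range(L):
        v = v + dt * grad_log_post(x)
        x = x + dt * v
    return x, v


def compile_seconds(fn, L):
    # Time lowering + XLA compilation only (no execution); L is static.
    x = jnp.ones(10)
    v = jnp.zeros(10)
    jitted = jax.jit(fn, static_argnums=(3,))
    start = time.perf_counter()
    jitted.lower(x, v, 0.01, L).compile()
    return time.perf_counter() - start


for name, fn in [("scan", integrate_scan), ("python loop", integrate_loop)]:
    print(name, compile_seconds(fn, L=10))
```

Both variants compute the same trajectory. Only the shape of the traced program, and therefore the compile time, differs.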
Added a compiled_leapfrog argument to the sghmc kernel.
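A boolean flag like this can select between the two loop styles inside the kernel. The following is a minimal sketch, not the library's actual code. `build_sghmc_step` and `one_step` are hypothetical names. The update assumes the SGHMC form with friction `alpha` and injected noise of variance `2 * alpha * dt`. The momentum resampling scale `sqrt(dt)` is also an assumption.

```python
import jax
import jax.numpy as jnp


def build_sghmc_step(grad_log_post, dt, L, alpha=0.01, compiled_leapfrog=True):
    """Hypothetical SGHMC step builder; compiled_leapfrog picks lax.scan vs a Python loop."""

    def one_step(x, v, key):
        # SGHMC-style update (assumed form): friction alpha, noise variance 2 * alpha * dt.
        noise = jax.random.normal(key, x.shape)
        v = v + dt * grad_log_post(x) - alpha * v + jnp.sqrt(2 * alpha * dt) * noise
        x = x + v
        return x, v

    def sghmc_step(key, x):
        key_v, key_steps = jax.random.split(key)
        # Resample momentum at the start of each step (scaling is an assumption).
        v = jnp.sqrt(dt) * jax.random.normal(key_v, x.shape)
        keys = jax.random.split(key_steps, L)
        if compiled_leapfrog:
            # Traced once, looped by XLA.
            def body(carry, k):
                return one_step(*carry, k), None

            (x, v), _ = jax.lax.scan(body, (x, v), keys)
        else:
            # Unrolled at trace time into L copies of one_step.
            for i in range(L):
                x, v = one_step(x, v, keys[i])
        return x

    return jax.jit(sghmc_step)


step = build_sghmc_step(lambda x: -x, dt=1e-2, L=10, compiled_leapfrog=False)
x_new = step(jax.random.PRNGKey(0), jnp.zeros(5))
```

Because `compiled_leapfrog` is a plain Python bool fixed when the kernel is built, it adds no runtime branching. It only changes which program gets traced and compiled.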