kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

Slow compilation of RQSpline flow operations #47

Closed kazewong closed 1 year ago

kazewong commented 2 years ago

Currently the training and sampling phases take around 3.5 minutes on an A100 node with an Ice Lake CPU for an RQSpline model with 10 layers and [128, 128] conditioners.

See whether replacing the layer loop in make_flow with a scan can solve this compilation problem. This could very well be an issue with how distrax forms the bijector; in that case, there won't be much I can do.
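
A rough sketch of the loop-versus-scan idea, under stated assumptions: the internals of make_flow are not shown in this thread, so the code below does not reproduce them. It uses a toy scale-only coupling layer as a stand-in for an RQSpline layer, and every name in it (`coupling`, `init_layer`, `forward_loop`, `forward_scan`) is hypothetical. The point is only that a Python `for` loop over layers is unrolled when traced, so the compiled program grows with the number of layers, while `jax.lax.scan` traces the layer body once over stacked parameters.

```python
import jax
import jax.numpy as jnp

N_LAYERS, DIM, HIDDEN = 10, 4, 128  # illustrative sizes, not the issue's exact model


def init_layer(key, flip):
    # Hypothetical parameters for a toy coupling layer (not an RQSpline).
    k1, k2 = jax.random.split(key)
    mask = (jnp.arange(DIM) % 2 == flip).astype(jnp.float32)
    return {
        "w1": 0.1 * jax.random.normal(k1, (DIM, HIDDEN)),
        "w2": 0.1 * jax.random.normal(k2, (HIDDEN, DIM)),
        "mask": mask,
    }


def coupling(p, x):
    # Scale-only coupling: masked dims pass through and condition the rest.
    m = p["mask"]
    h = jnp.tanh((x * m) @ p["w1"])
    s = jnp.tanh(h @ p["w2"]) * (1.0 - m)
    return x * jnp.exp(s), jnp.sum(s, axis=-1)


@jax.jit
def forward_loop(layers, x):
    # Python loop: each layer is traced separately, so the graph is unrolled.
    log_det = jnp.zeros(x.shape[0])
    for p in layers:
        x, ld = coupling(p, x)
        log_det = log_det + ld
    return x, log_det


@jax.jit
def forward_scan(stacked, x):
    # lax.scan: the layer body is traced once and applied along the leading axis.
    def step(carry, p):
        x, log_det = carry
        x, ld = coupling(p, x)
        return (x, log_det + ld), None

    (x, log_det), _ = jax.lax.scan(step, (x, jnp.zeros(x.shape[0])), stacked)
    return x, log_det


keys = jax.random.split(jax.random.PRNGKey(0), N_LAYERS)
layers = [init_layer(k, i % 2) for i, k in enumerate(keys)]
# Stack per-layer parameters so every leaf gains a leading N_LAYERS axis.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *layers)

x = jax.random.normal(jax.random.PRNGKey(1), (8, DIM))
y_loop, ld_loop = forward_loop(layers, x)
y_scan, ld_scan = forward_scan(stacked, x)
```

Both versions should give the same outputs up to floating-point tolerance. Timing the first call of each would show whether compile time actually drops for a given model. The scan form requires all layers to share the same parameter shapes, which is why the alternating mask is stored as a per-layer array rather than as Python control flow.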

kazewong commented 1 year ago

Distrax is now deprecated; solved in #123.