kazewong / flowMC

Normalizing-flow-enhanced sampling package for probabilistic inference in JAX
https://flowmc.readthedocs.io/en/main/
MIT License

170 lower precision training #171

Closed: kazewong closed this pull request 5 months ago.

kazewong commented 6 months ago

This PR adds a small utility for converting a model to a different precision. Interestingly, the compiler complains when the model is cast to bfloat16 or float16 and the input data is also float16 or bfloat16. If the data is float32, however, the code runs fine. I suspect this is because some type promotion happens inside the code.
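A minimal sketch of this kind of precision conversion, assuming the model parameters are held in a JAX pytree. The helper `to_precision`, the toy `apply` function, and the parameter names `w` and `b` are hypothetical and are not flowMC's API. The sketch also shows JAX's standard type promotion: bfloat16 weights combined with float32 data give a float32 result. That is consistent with (but does not prove) the guess that promotion is why the float32 case runs.

```python
import jax
import jax.numpy as jnp


def to_precision(params, dtype):
    """Cast every floating-point leaf of a parameter pytree to `dtype`.

    Hypothetical helper for illustration only; not flowMC's actual API.
    Integer and boolean leaves are left untouched.
    """
    def cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(cast, params)


# Hypothetical toy "model": a single affine layer stored as a dict pytree.
params = {
    "w": jnp.ones((2, 2), dtype=jnp.float32),
    "b": jnp.zeros((2,), dtype=jnp.float32),
}


def apply(params, x):
    # Affine map x @ w + b; the output dtype follows JAX's promotion rules.
    return x @ params["w"] + params["b"]


half_params = to_precision(params, jnp.bfloat16)

x32 = jnp.ones((3, 2), dtype=jnp.float32)
x16 = jnp.ones((3, 2), dtype=jnp.bfloat16)

# bfloat16 weights with float32 data: the result is promoted to float32.
print(apply(half_params, x32).dtype)  # float32
# bfloat16 weights with bfloat16 data: the computation stays in bfloat16.
print(apply(half_params, x16).dtype)  # bfloat16
```

Because promotion happens silently, a float32 input can mask whether the lower-precision path actually works. Testing with data in the same reduced dtype as the weights exercises that path directly.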

This PR also bumps the JAX version.