kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License

Put training loop into NF class #158

Closed: kazewong closed this 7 months ago

kazewong commented 7 months ago

Currently the training loop is constructed outside the normalizing flow model.

Since the function make_training_loop takes only one argument, the optax optimizer, it should be reasonably easy to move the train_flow function into the normalizing flow model. Ideally, the user could then simply call nf_model.fit(data).