kazewong / flowMC

Normalizing-flow enhanced sampling package for probabilistic inference in Jax
https://flowmc.readthedocs.io/en/main/
MIT License
199 stars 23 forks

Add parallelization over multiple devices #39

Open kazewong opened 2 years ago

kazewong commented 2 years ago

Currently the code runs on one device, which doesn't allow scaling to larger computational networks such as TPU pods.

Parallelizing the local sampler should be relatively simple, since it does not require communication between devices (see the sketch below). Note that if a single evaluation of the likelihood demands more RAM than what is available on a chip (TPUv4 has 8GB RAM per core; gradients of functions may cause problems), the computation may need to be sharded across multiple devices, but that should be taken care of separately.
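
As a rough illustration of the "no communication" point, here is a minimal sketch of running one independent chain per device with `jax.pmap`. This is not flowMC's actual API: `log_prob` and `rw_metropolis_step` are hypothetical stand-ins for the user-supplied likelihood and a local-sampler step.

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    # placeholder target density: standard Gaussian
    return -0.5 * jnp.sum(x**2)

def rw_metropolis_step(rng_key, position, step_size=0.1):
    # one random-walk Metropolis step; needs no cross-device communication
    key_prop, key_accept = jax.random.split(rng_key)
    proposal = position + step_size * jax.random.normal(key_prop, position.shape)
    log_accept = log_prob(proposal) - log_prob(position)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, proposal, position)

n_devices = jax.local_device_count()
n_dim = 5
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)  # one key per device
positions = jnp.zeros((n_devices, n_dim))                   # one chain per device

# pmap maps the leading axis over devices; each chain evolves independently
new_positions = jax.pmap(rw_metropolis_step)(keys, positions)
```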

Evaluating the global sampler should be similar to the local sampler.

Training the normalizing flow requires collecting data from multiple devices and updating the weights in a roughly synchronous fashion. Have a look at pmap to see how to deal with that: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html
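
For the synchronous update, a minimal data-parallel training-step sketch with `jax.pmap` and `jax.lax.pmean` could look like the following. The `loss_fn` is a placeholder (a real flow would return its negative log-likelihood), and the use of optax here is an assumption rather than flowMC's exact training loop.

```python
from functools import partial

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # placeholder loss; a real flow would return its negative log-likelihood here
    return jnp.mean((batch - params["mean"]) ** 2)

optimizer = optax.adam(1e-3)

@partial(jax.pmap, axis_name="devices")
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # average gradients across devices so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

n_devices = jax.local_device_count()
params = {"mean": jnp.zeros(5)}
opt_state = optimizer.init(params)

# replicate parameters/optimizer state, shard the training data over devices
params = jax.device_put_replicated(params, jax.local_devices())
opt_state = jax.device_put_replicated(opt_state, jax.local_devices())
data = jax.random.normal(jax.random.PRNGKey(0), (n_devices, 128, 5))

params, opt_state, loss = train_step(params, opt_state, data)
```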

kazewong commented 2 years ago

Also see this https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html

https://colab.research.google.com/github/marcvanzee/flax/blob/pjit-example/examples/siren/siren.ipynb
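
For reference, a minimal sketch of the sharding approach these links describe, using the newer `jax.sharding` API (pjit's functionality has since been folded into `jax.jit`); the `log_prob` below is a hypothetical stand-in for a batched likelihood.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def log_prob(x):
    # placeholder per-sample log density: standard Gaussian
    return -0.5 * jnp.sum(x**2, axis=-1)

# 1D mesh over all available devices, with one named axis for the batch dimension
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split the leading (batch) axis over devices

batch = jax.random.normal(jax.random.PRNGKey(0), (jax.device_count() * 128, 5))
batch = jax.device_put(batch, sharding)    # place each shard on its device

# jit compiles one program; XLA partitions the computation following the input sharding
log_probs = jax.jit(log_prob)(batch)
```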

ahnitz commented 1 year ago

@kazewong Is there any thought on parallelization over multi-node CPU resources, such as through MPI?

ahnitz commented 1 year ago

Whoops, I see you have an issue for that already: https://github.com/kazewong/flowMC/issues/61

zeeshan5885 commented 1 month ago

We are facing a similar issue with population inference of BBHs in GWKokab for more than 100 events. Cluster resources get exhausted for a large number of events, and fewer events do not give me converging chains with heavy models. This feature is much needed and would improve the scalability of population inference of compact binaries using GW data.