kazewong opened this issue 2 years ago
@kazewong Has there been any thought given to parallelization over multi-node CPU resources, such as through MPI?
Whoops, I see you already have an issue for that: https://github.com/kazewong/flowMC/issues/61
We are facing a similar issue for population inference of BBHs with GWKokab when using more than 100 events. The cluster resources get exhausted for a large number of events, while using fewer events does not give converged chains with heavy models. This feature is much needed and would improve the scalability of population inference of compact binaries using GW data.
Currently the code runs on a single device, which prevents scaling to larger computational resources such as TPU pods.
Parallelizing the local sampler should be relatively simple, since it does not require communication between devices. Note that if a single evaluation of the likelihood demands more RAM than is available on a chip (TPUv4 has 8GB RAM per core, and taking gradients of functions may cause problems), the computation may need to be sharded across multiple devices, but that should be handled separately.
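A minimal sketch of what device-level parallelism for the local sampler could look like, assuming a plain random-walk Metropolis kernel as a stand-in for flowMC's actual local samplers. `log_prob`, `rw_metropolis_step`, `run_chains`, the standard-normal target and the step size are all hypothetical. Chains are `vmap`ped within a device and `pmap`ped across devices; because each device only updates its own chains, no cross-device collective is needed.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical target: standard normal (stand-in for a real likelihood).
    return -0.5 * jnp.sum(x ** 2)


def rw_metropolis_step(key, x, step_size=0.5):
    # One random-walk Metropolis step for a single chain.
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    log_ratio = log_prob(proposal) - log_prob(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)


def run_chains(key, xs, n_steps=100):
    # Advances all chains held by ONE device: vmap over chains, scan over steps.
    def one_step(xs, step_key):
        chain_keys = jax.random.split(step_key, xs.shape[0])
        return jax.vmap(rw_metropolis_step)(chain_keys, xs), None

    step_keys = jax.random.split(key, n_steps)
    final, _ = jax.lax.scan(one_step, xs, step_keys)
    return final


n_devices = jax.local_device_count()
n_chains_per_device, n_dim = 8, 2
# One PRNG key and one block of chains per device (leading axis = device axis).
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
x0 = jnp.zeros((n_devices, n_chains_per_device, n_dim))
final = jax.pmap(run_chains)(keys, x0)
```

On a machine with a single device this still runs; the leading device axis simply has size 1.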
Evaluating the global sampler should be similar to the local sampler.
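Along the same lines, a sketch of a global step, assuming it is an independence Metropolis-Hastings move whose proposals come from the flow. Here the "flow" is a hypothetical fixed diagonal Gaussian; `log_target`, `FLOW_MEAN`, `FLOW_SCALE`, `flow_sample`, `flow_log_prob` and `global_step` are illustrative names, not flowMC's API. Each device draws its own proposals and evaluates the acceptance ratio locally, so once the flow parameters are available on every device no communication is required.

```python
import jax
import jax.numpy as jnp


def log_target(x):
    # Hypothetical target density (stand-in for the real posterior).
    return -0.5 * jnp.sum((x - 1.0) ** 2)


# Hypothetical "trained flow": a diagonal Gaussian with fixed mean and scale.
FLOW_MEAN = jnp.array([0.8, 1.2])
FLOW_SCALE = jnp.array([1.5, 1.5])


def flow_sample(key, n):
    return FLOW_MEAN + FLOW_SCALE * jax.random.normal(key, (n, FLOW_MEAN.shape[0]))


def flow_log_prob(x):
    z = (x - FLOW_MEAN) / FLOW_SCALE
    return jnp.sum(-0.5 * z ** 2 - jnp.log(FLOW_SCALE) - 0.5 * jnp.log(2 * jnp.pi), axis=-1)


def global_step(key, xs):
    # Independence Metropolis-Hastings for all chains on one device:
    # accept x' with probability min(1, w(x') / w(x)), where w = target / flow.
    key_prop, key_acc = jax.random.split(key)
    proposals = flow_sample(key_prop, xs.shape[0])
    log_w_prop = jax.vmap(log_target)(proposals) - flow_log_prob(proposals)
    log_w_curr = jax.vmap(log_target)(xs) - flow_log_prob(xs)
    log_u = jnp.log(jax.random.uniform(key_acc, (xs.shape[0],)))
    accept = log_u < log_w_prop - log_w_curr
    return jnp.where(accept[:, None], proposals, xs)


n_devices = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(1), n_devices)
xs = jnp.ones((n_devices, 8, 2))
new_xs = jax.pmap(global_step)(keys, xs)
```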
Training the normalizing flow requires collecting data from multiple devices and updating the weights in a more or less synchronous way. Have a look at `pmap` to see how to deal with that: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html
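A minimal data-parallel training sketch with `pmap`, assuming the flow is fit by maximum likelihood on samples that are already split across devices. The "flow" is a hypothetical single diagonal affine layer, plain gradient descent stands in for whatever optimizer is actually used, and `flow_log_prob`, `loss_fn` and `train_step` are illustrative names. `jax.lax.pmean` averages the gradients across the named device axis, giving one synchronous weight update shared by all replicas.

```python
import jax
import jax.numpy as jnp


def flow_log_prob(params, x):
    # Hypothetical one-layer affine flow: z = (x - shift) * exp(-log_scale), z ~ N(0, I).
    z = (x - params["shift"]) * jnp.exp(-params["log_scale"])
    log_base = -0.5 * jnp.sum(z ** 2, axis=-1) - 0.5 * x.shape[-1] * jnp.log(2 * jnp.pi)
    return log_base - jnp.sum(params["log_scale"])


def loss_fn(params, batch):
    # Negative mean log-likelihood of the local batch.
    return -jnp.mean(flow_log_prob(params, batch))


def train_step(params, batch, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients (and loss) across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


p_train_step = jax.pmap(train_step, axis_name="devices")

n_devices = jax.local_device_count()
n_dim = 2
params = {"shift": jnp.zeros(n_dim), "log_scale": jnp.zeros(n_dim)}
# Replicate the parameters by adding a leading device axis.
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_devices), params)
# Hypothetical training data (e.g. samples gathered from the chains), one shard per device.
data = 3.0 + 0.5 * jax.random.normal(jax.random.PRNGKey(0), (n_devices, 256, n_dim))

losses = []
for _ in range(200):
    replicated, loss = p_train_step(replicated, data)
    losses.append(float(loss[0]))
```

Because the gradients are averaged before the update, the replicated parameters stay identical on every device, which is what keeps the weight update synchronous.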