facebookresearch / fmmax

Fourier modal method with Jax
MIT License

Batching with pmap #46

Status: Open. smartalecH opened this issue 11 months ago

smartalecH commented 11 months ago

Currently, FMMAX performs a lot of broadcasting under the hood, which makes it easy to simulate a device at multiple wavelengths and/or multiple k-points. Each of these is a completely independent simulation, so the work is embarrassingly parallel.
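To make the batching concrete, here is a minimal sketch of evaluating a Cartesian grid of wavelengths and k-points, first by broadcasting and then by nested `jax.vmap`. `toy_simulation` is a hypothetical scalar stand-in for one independent solve, not part of the FMMAX API, and FMMAX's own broadcasting conventions may differ.

```python
import jax
import jax.numpy as jnp


def toy_simulation(wavelength, kx):
    # Hypothetical stand-in for one independent FMM solve (not an FMMAX API).
    return jnp.cos(2 * jnp.pi * kx / wavelength) ** 2


wavelengths = jnp.linspace(0.45, 0.65, 5)  # shape (5,)
kxs = jnp.linspace(0.0, 0.5, 3)  # shape (3,)

# Broadcasting: (5, 1) against (1, 3) yields the full (5, 3) Cartesian grid.
grid_broadcast = toy_simulation(wavelengths[:, None], kxs[None, :])

# Same grid with nested vmap: inner maps over kx, outer maps over wavelength.
inner = jax.vmap(toy_simulation, in_axes=(None, 0))
grid_vmap = jax.vmap(inner, in_axes=(0, None))(wavelengths, kxs)
```

Every grid point is independent, which is what makes splitting the grid across devices straightforward in principle.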

In some cases, the full Cartesian product of simulations does not fit in the memory of a single accelerator, and it would be nice to distribute the work across multiple accelerators.

JAX supports this through pmap. It might be nice to set up an example that takes an arbitrary combination of wavelengths and k-points and distributes the computation across all available accelerators. There are several things to consider, of course. First, the eigendecomposition actually dispatches back to the host, so if every device performs the same dispatch, this could quickly become a bottleneck. Second, pmap semantics have limitations when the number of parallel jobs is not an integer multiple of the number of local accelerators, or when the accelerators live on different nodes.
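One possible way to handle a job count that is not a multiple of the local device count is to flatten the grid, pad it up to the next multiple, reshape it to `(n_devices, per_device)`, and run `jax.pmap` over the leading axis with `jax.vmap` inside each device. This is a sketch under assumptions: `toy_simulation` and `distributed_sweep` are hypothetical names, padding repeats the last grid point (harmless but wasted work), and it only covers accelerators on a single host. It does not address the multi-node case or the host-side eigendecomposition bottleneck.

```python
import jax
import jax.numpy as jnp


def toy_simulation(wavelength, kx):
    # Hypothetical stand-in for one independent FMM solve (not an FMMAX API).
    return jnp.cos(2 * jnp.pi * kx / wavelength) ** 2


# pmap over the device axis; vmap over the jobs assigned to each device.
_sharded_simulation = jax.pmap(jax.vmap(toy_simulation))


def distributed_sweep(wavelengths, kxs):
    """Evaluates toy_simulation on the wavelength x kx grid across local devices."""
    wl_grid, kx_grid = jnp.meshgrid(wavelengths, kxs, indexing="ij")
    wl_flat, kx_flat = wl_grid.ravel(), kx_grid.ravel()

    n_jobs = wl_flat.size
    n_devices = jax.local_device_count()
    per_device = -(-n_jobs // n_devices)  # ceiling division
    pad = per_device * n_devices - n_jobs

    # Repeat the last grid point so every device receives the same number of jobs.
    wl_flat = jnp.pad(wl_flat, (0, pad), mode="edge")
    kx_flat = jnp.pad(kx_flat, (0, pad), mode="edge")

    out = _sharded_simulation(
        wl_flat.reshape(n_devices, per_device),
        kx_flat.reshape(n_devices, per_device),
    )
    # Drop the padded results and restore the (n_wavelengths, n_kx) grid shape.
    return out.reshape(-1)[:n_jobs].reshape(wl_grid.shape)
```

For example, 7 wavelengths and 5 k-points give 35 jobs; on 4 devices each device gets 9 jobs, and one of them is a padded duplicate.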

smartalecH commented 10 months ago

A simple example comparing vmap to pmap would be useful here.
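A minimal vmap-versus-pmap comparison might look like the following sketch. Both versions evaluate the same hypothetical `toy_simulation` on (wavelength, kx) pairs arranged so the leading axis equals the local device count. `jax.vmap` batches everything on one device, while `jax.pmap` places one row on each device and the inner `jax.vmap` batches within it. The per-device batch size of 64 and the `time_call` helper are illustrative assumptions, and the timings only mean something with more than one accelerator and a workload heavier than this toy function.

```python
import time

import jax
import jax.numpy as jnp


def toy_simulation(wavelength, kx):
    # Hypothetical stand-in for one independent FMM solve (not an FMMAX API).
    return jnp.cos(2 * jnp.pi * kx / wavelength) ** 2


def time_call(fn, *args, repeats=10):
    """Returns the mean wall-clock time per call, excluding compilation."""
    fn(*args).block_until_ready()  # first call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats


n_devices = jax.local_device_count()
per_device = 64
shape = (n_devices, per_device)
wavelengths = jnp.linspace(0.45, 0.65, n_devices * per_device).reshape(shape)
kxs = jnp.linspace(0.0, 0.5, n_devices * per_device).reshape(shape)

# vmap: both axes batched on a single (default) device.
vmap_fn = jax.jit(jax.vmap(jax.vmap(toy_simulation)))
# pmap: leading axis split across devices, inner axis batched per device.
pmap_fn = jax.pmap(jax.vmap(toy_simulation))

vmap_out = vmap_fn(wavelengths, kxs)
pmap_out = pmap_fn(wavelengths, kxs)
assert jnp.allclose(vmap_out, pmap_out)

print(f"devices: {n_devices}")
print(f"vmap: {time_call(vmap_fn, wavelengths, kxs) * 1e3:.3f} ms per call")
print(f"pmap: {time_call(pmap_fn, wavelengths, kxs) * 1e3:.3f} ms per call")
```

On a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` before importing jax exposes eight host devices. That is enough to exercise the pmap path, though it says little about real accelerator speedups.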