Closed by maedoc 1 year ago
After some quick tests with JAX's built-in distributed functions, they seem like the simpler approach from a development standpoint: for single-machine cases no MPI is required, just set
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=2"
before importing JAX; jax.pmap will then distribute work over the available devices, within a single process.
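To make the single-machine case concrete, here is a minimal sketch of the setup described above. The `step` function and the array shapes are placeholders chosen for illustration, not anything from this project; the device count is read back from JAX rather than assumed, since the flag only affects the host (CPU) platform.

```python
import os

# Must be set before JAX is imported; the value 2 is the one quoted above.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

# Ask JAX how many devices it actually sees instead of hard-coding 2.
n_dev = jax.local_device_count()


def step(x):
    # Hypothetical per-device workload, for illustration only.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


# pmap maps over the leading axis, so that axis must equal the device count.
xs = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
ys = jax.pmap(step)(xs)

print(n_dev, ys.shape)
```

Each row of `xs` is processed on its own device, all from one Python process.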
For the multi-process context, some more work is required. The two main cases would be multi-node CPU and GPU. For interactive use, we could use Dask/IPyParallel to wrap the MPI-like lockstep execution that JAX parallel primitives require.
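As a rough illustration of what "lockstep" means here, the sketch below runs a collective (`jax.lax.psum`) inside `jax.pmap` on a single machine. It is not a multi-node setup; it only shows that every device has to reach the same collective call. The `normalize` function and the axis name `"devices"` are made up for this example.

```python
import os

# Set before importing JAX so the host platform exposes several devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def normalize(x):
    # psum is a collective: every device contributes its local sum and
    # every device receives the global total, so all must call it together.
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total


xs = jnp.ones((n_dev, 3), dtype=jnp.float32)
ys = jax.pmap(normalize, axis_name="devices")(xs)

print(ys.shape, float(ys.sum()))
```

In a multi-process run the same program would have to be launched on every process, which is the part a Dask/IPyParallel wrapper would need to coordinate.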
There's interesting work going on in JAX v0.4 (https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) that could push for using the built-in primitives.
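For reference, a minimal sketch of the distributed-array style from the first link, assuming the `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`). The mesh axis name `"i"` and the function `f` are illustrative choices only.

```python
import os

# Set before importing JAX so the host platform exposes several devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all visible devices; "i" is an arbitrary name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("i",))
sharding = NamedSharding(mesh, P("i"))

# Split the array along its only axis across the mesh.
x = jnp.arange(len(devices) * 4, dtype=jnp.float32)
x = jax.device_put(x, sharding)


@jax.jit
def f(x):
    # Ordinary jitted code; the compiler decides how to partition the work.
    return jnp.tanh(x) * 2.0


y = f(x)
print(y.shape, len(x.sharding.device_set))
```

Unlike `jax.pmap`, the function sees the full logical array and no explicit per-device axis is needed.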
"automatic parallelization" 🤤.. 🤕
I couldn't get the JAX parallel API to work, and the distributed systems that we'll have to work with are MPI-based, so expertise is mainly in MPI.
options
JAX's built-in option seems to handle MPI automatically, so it might be the better starting point.