ins-amu / vbjax

A nascent Jax-based package for virtual brain modeling.
Apache License 2.0

Test Jax parallel APIs #19

Closed maedoc closed 1 year ago

maedoc commented 1 year ago

options

Jax's built-in option seems to handle MPI automatically, so it might be the better starting point.

maedoc commented 1 year ago

After some quick tests with Jax's built-in distributed functions, this seems like the simpler approach from a development standpoint. For single-machine cases no MPI is required; just set

import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=2"

before importing Jax; jax.pmap will then distribute work over the available devices within a single process.
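For reference, a minimal sketch of that single-process setup. The `step` function and the array shapes are made up for illustration and are not vbjax code; the leading axis of the input has to match the number of devices Jax reports, so it is queried rather than hard-coded.

```python
import os

# Must be set before Jax is first imported; asks XLA to expose 2 host (CPU) devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

# Number of devices visible to this process (2 on a CPU-only machine with the flag above).
n_dev = jax.local_device_count()


def step(x):
    # Hypothetical per-device work, standing in for a real model update.
    return jnp.tanh(x) * 2.0


# One row of work per device: pmap maps over the leading axis.
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
y = jax.pmap(step)(x)
```

Each row of `x` is processed on its own device, and `y` comes back with the same `(n_dev, 4)` shape.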

For a multiprocess context, some more work is required. The two main cases would be multi-node CPU and multi-node GPU. For interactive use, we could use Dask/IPyParallel to wrap the MPI-like lockstep execution that Jax parallel primitives require.
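To illustrate what "lockstep" means here, a small single-process sketch using a collective: every device has to reach the `psum` call for the reduction to complete, much like an MPI allreduce. The function name `normalized` and the axis name `"i"` are arbitrary choices for this example, and it does not cover the multi-process setup itself.

```python
import os

# Only takes effect if set before Jax is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def normalized(x):
    # Each device sums its own shard, then psum adds those partial sums
    # across all devices along the mapped axis "i" (an allreduce).
    total = jax.lax.psum(jnp.sum(x), axis_name="i")
    return x / total


x = jnp.arange(1, n_dev * 3 + 1, dtype=jnp.float32).reshape(n_dev, 3)
out = jax.pmap(normalized, axis_name="i")(x)
```

After the call, the entries of `out` sum to one across all devices, which only works because each device took part in the same reduction.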

maedoc commented 1 year ago

There's interesting work going on in Jax v0.4 https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html that could push for using built in primitives.
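As a rough idea of what the distributed-arrays approach looks like, here is a sketch assuming a Jax version that ships `jax.sharding` (0.4.1 or later). The mesh axis name `"x"` and the function `row_sums` are placeholders for this example; the point is that the array carries its sharding and `jax.jit` partitions the computation without an explicit pmap.

```python
import os

# Only takes effect if set before Jax is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
n_dev = len(devices)

# A 1-D device mesh with a single named axis "x".
mesh = Mesh(np.array(devices), ("x",))
# Split the leading array axis across the mesh axis "x"; other axes stay whole.
sharding = NamedSharding(mesh, PartitionSpec("x"))

data = jnp.arange(n_dev * 4 * 3, dtype=jnp.float32).reshape(n_dev * 4, 3)
data = jax.device_put(data, sharding)


@jax.jit
def row_sums(a):
    # Ordinary array code; the compiler decides how to run it on the shards.
    return jnp.sin(a).sum(axis=1)


result = row_sums(data)
```

The leading dimension is chosen as a multiple of the device count so it divides evenly across the mesh.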

"automatic parallelization" 🤤.. 🤕

maedoc commented 1 year ago

I couldn't get the jax parallel API to work, and the distributed systems that we'll have to work with are MPI-based, so expertise is mainly in MPI.