jax-ml / jax

Make automatic distributed initialization work with Open MPI 5.x, PMIx, and PRRTE #14576

Open EricHallahan opened 1 year ago

EricHallahan commented 1 year ago

Background

https://github.com/google/jax/pull/13929 introduced automatic JAX distributed initialization via the Open MPI Open Run-Time Environment (ORTE) layer and its orterun process launcher (also known by its many aliases: mpirun, mpiexec, oshrun, and shmemrun).

The upcoming Open MPI 5.x series replaces the previous ORTE infrastructure with one built around the PMIx standard, via the OpenPMIx reference PMIx implementation and the complementary PMIx Reference Run-Time Environment (PRRTE); in Open MPI 5.x the mpirun/mpiexec launcher is simply a wrapper around the PRRTE prterun launcher.

PMIx and PRRTE behave differently from ORTE, which makes the implementation introduced in https://github.com/google/jax/pull/13929 incompatible with Open MPI 5.x. With Open MPI 5.0 (now at its tenth release candidate) continuing to approach release, there seems to be value in preparing JAX for this change.

Considerations & Challenges

Continued compatibility with ORTE and orterun

The current implementation (as introduced in https://github.com/google/jax/pull/13929) is fully usable with Open MPI versions prior to 5.x, and it is important to maintain compatibility with those releases when introducing support for Open MPI 5.x. It is unclear to me whether it would be wiser to extend the current implementation to also handle the PRRTE-based launcher, or to handle it in a separate piece of code.

New behaviors

PMIx/PRRTE expose the relevant information through a different set of environment variables than ORTE does (a short detection sketch follows the listing below):

OMPI_VERSION=5.0.0rc10
OMPI_TOOL_NAME=mpirun
PRTE_LAUNCHED=1
PMIX_NAMESPACE=prterun-%{hostname}-%{num_job_id}@1
PMIX_RANK=0
PMIX_SERVER_URI41=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI4=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI3=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI2=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_URI21=prterun-%{hostname}-%{num_job_id}@0.0;tcp4://%{ip_v4_addr}:%{port}
PMIX_SERVER_TMPDIR=/tmp/prte.%{hostname}.%{uid}/dvm.%{num_job_id}
PMIX_SYSTEM_TMPDIR=/tmp
OMPI_COMM_WORLD_SIZE=1
OMPI_WORLD_SIZE=1
OMPI_MCA_num_procs=1
OMPI_COMM_WORLD_RANK=0
OMPI_COMM_WORLD_LOCAL_RANK=0
OMPI_COMM_WORLD_NODE_RANK=0
PMIX_HOSTNAME=%{hostname}
Above is a selection of the variables exposed to a process launched with mpirun under Open MPI 5.x.
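
For context, a rough sketch of how a process could tell the two launchers apart from these variables (assuming PRTE_LAUNCHED is reliably set to 1 by prterun, as the dump above suggests):

import os

# Rough illustration only: distinguish ORTE-based from PRRTE-based launches.
# ORTE (Open MPI <= 4.x) defines OMPI_MCA_orte_hnp_uri, while PRRTE
# (Open MPI 5.x) defines PRTE_LAUNCHED=1 and the PMIX_* variables instead.
if 'OMPI_MCA_orte_hnp_uri' in os.environ:
    launcher = 'orterun (Open MPI <= 4.x)'
elif os.environ.get('PRTE_LAUNCHED') == '1':
    launcher = 'prterun (Open MPI 5.x)'
else:
    launcher = 'not launched by mpirun'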
mjsML commented 1 year ago

@EricHallahan are you interested in implementing the support? cc: @nvcastet

EricHallahan commented 1 year ago

@EricHallahan are you interested in implementing the support?

Prior to filing this issue, I made a patch to the existing implementation to make it work with Open MPI 5.x. I am willing to contribute it, but the question remains how to maintain support for earlier versions (something I didn't need to consider for my personal use), as they are going to remain in use for many years to come.

nvcastet commented 1 year ago

@EricHallahan Thanks a lot for raising this issue and for the thorough discussion! We detect current OMPI with ORTE via the presence of the env var OMPI_MCA_orte_hnp_uri (see https://github.com/google/jax/blob/main/jax/_src/clusters/ompi_cluster.py#L28-L29). For Open MPI with PRRTE, you could contribute your patch as a new subclass of ClusterEnv; the OMPI ORTE detection will not interfere, since OMPI_MCA_orte_hnp_uri is not defined under PRRTE. Does that make sense?
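
For illustration, a minimal sketch of what such a ClusterEnv subclass could look like, assuming the classmethod interface used by the existing ORTE-based OmpiCluster (is_env_present, get_coordinator_address, get_process_count, get_process_id, get_local_process_id). The class name, the PMIX_SERVER_URI21 parsing, and the port derivation below are assumptions for illustration, not the actual patch:

import os
from typing import Optional

from jax._src import clusters

_PRTE_LAUNCHED = 'PRTE_LAUNCHED'        # set to 1 by prterun/mpirun under Open MPI 5.x
_SERVER_URI = 'PMIX_SERVER_URI21'       # e.g. 'prterun-<host>-<jobid>@0.0;tcp4://<ip>:<port>'
_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'


class OmpiPrrteCluster(clusters.ClusterEnv):

  @classmethod
  def is_env_present(cls) -> bool:
    # ORTE-based launches define OMPI_MCA_orte_hnp_uri instead, so this check
    # does not overlap with the existing OmpiCluster detection.
    return os.environ.get(_PRTE_LAUNCHED) == '1' and _SERVER_URI in os.environ

  @classmethod
  def get_coordinator_address(cls) -> str:
    # Take the launcher's IP from the PMIx server URI, but derive the JAX
    # coordinator port deterministically from the PRRTE job id rather than
    # reusing the PMIx server's own port, which prterun has already bound.
    uri = os.environ[_SERVER_URI]
    job_id = int(uri.split('@', maxsplit=1)[0].rsplit('-', maxsplit=1)[1])
    launcher_ip = uri.split('://', maxsplit=1)[1].rsplit(':', maxsplit=1)[0]
    port = job_id % 2**12 + (65535 - 2**12 + 1)
    return f'{launcher_ip}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> Optional[int]:
    return int(os.environ[_LOCAL_PROCESS_ID])

If I read jax/_src/clusters/cluster.py correctly, ClusterEnv subclasses register themselves when defined, so a class along these lines would be picked up by jax.distributed.initialize() whenever the PRRTE variables are present.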

EricHallahan commented 1 year ago

That is certainly a valid option! I'll go ahead and try that.

PhilipVinc commented 1 month ago

@EricHallahan could you contribute your patch, or maybe put it out there? It would be useful for me as well.

PhilipVinc commented 1 month ago

(And do you know if there is a way to do the same for MPICH by any chance?)

nvcastet commented 1 month ago

You can use auto-detection via mpi4py for that. See https://github.com/google/jax/pull/20174
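
If I recall that PR correctly, the mpi4py-based detection is enabled explicitly when initializing; a minimal sketch (the cluster_detection_method argument and the mpi4py requirement are as I understand them from that PR):

import jax

# Sketch assuming the opt-in added in https://github.com/google/jax/pull/20174;
# requires mpi4py, and should cover MPICH as well, since mpi4py hides the
# launcher-specific environment variables.
jax.distributed.initialize(cluster_detection_method='mpi4py')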

PhilipVinc commented 1 month ago

I know, but I don't want to. I get some errors with that...