PhilipVinc opened this issue 2 years ago (status: open)
The main thing that JAX is missing to make this work is an implementation of the various collective operations in XLA that works across processes.
One possibility is MPI, in which case the third-party package mpi4jax may be of interest.
Another possibility might be to plug something like gloo into XLA to implement its collectives: https://github.com/facebookincubator/gloo
This would probably not be hard to do. Currently the collectives implemented in XLA/CPU are naive reference implementations.
(I am the author of mpi4jax, so I know that one pretty well 😄)
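For reference, mpi4jax exposes MPI collectives inside jitted functions roughly like this (a minimal sketch; it assumes mpi4py is available, the script is launched under an MPI launcher such as mpirun, and the exact allreduce signature may vary between mpi4jax versions):

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_mean(x):
    # mpi4jax inserts an MPI allreduce as a custom call inside the jitted program.
    summed, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return summed / comm.Get_size()

print(global_mean(jnp.arange(4.0)))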
Thanks for the explanation, I see the issue now.
The reason I'm asking is mainly because pjit and distributing along different axes of a mesh are very interesting to me, and I would like to play with them in our packages, but MPI limits us to something like data parallelism only. Unfortunately I have many CPU-based users, which is why I need CPU support.
The other reason is that I'd like to support multiple GPUs/CPUs per node (which is exactly what pjit/GlobalArray does), but that would mean supporting mpi primitives within pmap, which... I'm not sure how to do. Right now I have to force users to launch one JAX process per GPU, which is particularly annoying in some HPC setups.
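To make that concrete, the kind of mesh-based sharding I'd like to play with looks roughly like this (a sketch using the current jax.sharding API rather than the raw pjit spelling; the 2x2 mesh and axis names are just illustrative, and on CPU you need something like XLA_FLAGS=--xla_force_host_platform_device_count=4 to get four local devices):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange four local devices into a 2x2 mesh with a "data" and a "model" axis.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("data", "model"))

# Shard a matrix across both mesh axes.
x = jax.device_put(jnp.ones((8, 8)), NamedSharding(mesh, P("data", "model")))

@jax.jit
def f(a):
    return (a * 2).sum()  # the full reduction forces cross-device communication

print(f(x))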
Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar. This is a bit different to mpi4jax, where you added a separate set of collective ops unknown to XLA on the side.
Right now, XLA:CPU emits calls to functions in a small helper library that implement collectives: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/cpu/cpu_runtime.h;drc=6eeb889576593a803bce51871b11fb2b27f8f2b3;l=174
Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.) That would most likely need a bit of refactoring so multiple collective implementations can be plugged in.
> Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar. This is a bit different to mpi4jax where you added a separate set of collective ops unknown to XLA on the side.
Hmm, yes, that would be amazing, but I can imagine it would be quite a bit of work and I'm unsure whether Google is interested in that. Though surely academic groups working with HPC would be interested and would benefit from it.
> Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.) That would most likely need a bit of refactoring so multiple collective implementations can be plugged in.
Getting my hands dirty in XLA itself is a bit beyond the amount of time I have available right now, unfortunately. Is there anything I can do to help you in the process, or to convince you that this is a useful path?
[begin off topic]
However, something that might make me temporarily, slightly happier would be a way to insert MPI custom calls into pmap-ped functions. Right now we define C functions that respect the XLA calling convention and then specify how to encode those on CPU and GPU, but I have no idea how to support pmap in this context. If you have any pointers I'd be happy to take them.
[end off topic]
@hawkinsp did anything change on your end about this recently, or are you still not really planning on supporting this?
@hawkinsp, are there any updates on this front? I would like to start using JAX on our CPU-based cluster.
Hello, I may be interested in taking this on by implementing gloo-based collective ops, replacing the naive implementations. How does this sound, @PhilipVinc @hawkinsp?
I will start with all-reduce and broadcast, in that order.
@jon-chuang well that sounds excellent, if you wanted to contribute that! I would agree: start with all-reduce, which is by itself enough for data-parallel training.
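Concretely, once cross-process CPU collectives exist, the data-parallel pattern from the user's side would be roughly the following, run once per process (a sketch; the coordinator address, the process count, and the PROCESS_ID environment variable are placeholders for whatever your launcher provides):

import os
import jax
import jax.numpy as jnp

# All processes must agree on the coordinator address and the total count.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=int(os.environ["PROCESS_ID"]),
)

# psum over the named axis reduces across the devices of *all* processes.
xs = jnp.ones(jax.local_device_count())
total = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(jax.process_index(), total)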
Here is the MVP target: psum on separate processes running locally with the XLA CPU runtime; e2e test (specifically, multiprocess_cpu_test.py, similar to the GPU test).

@PhilipVinc could you advise on the degree to which we should be able to perform a collective operation across both CPU and GPU (e.g. GPU+CPU offloading)? In this case:

- CPU<->GPU: MPI (incl. cross-process, same device)
- CPU<->CPU: MPI
- GPU<->GPU: NCCL

The way we could implement it, e.g. for all-reduce: CPU<->GPU + GPU<->GPU (same host), and GPU<->GPU or CPU<->CPU (cross-host). I think that GPU<->GPU should have better performance via NCCL?
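In plain Python, the same-host / cross-host split for all-reduce would look roughly like this (purely illustrative; numpy stands in for the NCCL and MPI stages, and the helper name is hypothetical):

import numpy as np

def hierarchical_allreduce(per_device_values, devices_per_host):
    # Group the per-device buffers by host.
    hosts = [per_device_values[i:i + devices_per_host]
             for i in range(0, len(per_device_values), devices_per_host)]
    # Stage 1: intra-host reduction (NCCL for GPUs on the same host).
    host_sums = [np.sum(h, axis=0) for h in hosts]
    # Stage 2: inter-host reduction (MPI or gloo over the network).
    total = np.sum(host_sums, axis=0)
    # Stage 3: broadcast the result back to every device.
    return [total for _ in per_device_values]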
See also: https://github.com/alpa-projects/alpa/issues/694
Note that I don't think that even torch.distributed allows for a hybrid cluster?
EDIT: gloo implements local reduction in CPU memory - see e.g. cuda_allreduce_ring, cuda_allreduce_ring_chunked. The latter leverages NCCL for same-host multi-GPU reduce/scatter.
@jon-chuang thank you for looking into this. It's something that would greatly benefit many people including me...
As for your question, if I understand it correctly, you want to know which reduction operations must be implemented? In XLA, as of today, there exist only CPU-based reductions (so CPU to CPU) or GPU-based reductions (so GPU to GPU). That's because an XLA-compiled executable can only run on one platform.
So you should not worry about CPU-GPU reductions and can always assume that the devices executing your distributed operation are homogeneous. At least, that assumption has worked very well for mpi4jax so far.
There might be plans to allow for hybrid computations (@hawkinsp will know for sure) but I'd leave that out of scope for a first implementation.
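(A quick way to see this from Python: the devices visible to a single computation all report the same platform.)

import jax

# A compiled XLA executable only ever spans devices of one platform,
# so this prints e.g. {'cpu'} or {'gpu'}, never a mix.
print({d.platform for d in jax.devices()})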
--
I think the only operation you need to implement is CPU-CPU reductions, possibly using MPI. GPU-GPU reductions are already implemented (using NCCL I think).
Actually, I got a hint that the new PJRT runtime can handle a mixed CPU<->GPU workload. Could you confirm @hawkinsp ?
@jon-chuang there are explorations in that direction, but nothing concrete at this time. It might also be done primarily at a layer above PJRT even if it happens.
I would not look into hybrid computations for an MVP.
Also, somewhat relevant: how are you going to implement this? Ideally this could be a plugin to XLA (do those even exist?) that depends on MPI. Not so far off from the current compile step of mpi4jax.
Forcing users to recompile the full jaxlib on HPC machines with finicky compilers is going to be a recipe for problems. Maybe good for an MVP, but in the long run it will be hard to switch users to it.
> Maybe good for an MVP, but in the long run it will be hard to switch users to it.
As far as plugins go, CMIIW, XLA already has support for such dynamically-loaded runtime libraries, as they want to support user-side custom call lowering/dispatch.
I did consider baking it into JAX/XLA, but the plugin way seems neater, and it could come bundled with JAX/XLA if deemed the most reasonable default.
I added a secret option jax_cpu_enable_gloo_collectives in https://github.com/google/jax/commit/384e29e30d000f1e9c7d7d4a52eadb0ef8a8141a. This enables cross-process CPU support (needs jaxlib from head, built with xla from head).
Please note that:
a) there is no support for encryption: your data will travel unencrypted over the wire. This may or may not be acceptable depending on the application.
b) the collectives are currently synchronous, so they won't be that fast, yet.
c) the collectives are only lightly tested so far.
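Something like the following should turn it on (a sketch, untested; whether the option is set via jax.config.update or only via the corresponding JAX_CPU_ENABLE_GLOO_COLLECTIVES environment variable may depend on the exact jax version):

import jax

# Hypothetical usage of the option named above; needs jaxlib/xla built from head.
jax.config.update("jax_cpu_enable_gloo_collectives", True)
jax.distributed.initialize()  # one JAX process per host, as usual
# Cross-process collectives on the CPU backend should now go through gloo.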
I should add: it wouldn't be terribly hard to plug in MPI collectives here as well, if one wanted to do so. @PhilipVinc
(One implements: https://github.com/openxla/xla/blob/main/xla/service/cpu/collectives_interface.h essentially.)
@inailuig has recently (successfully) experimented with plugging MPI inside a CPU device... He can say more than I can.
However, what we would love is a way to write a sort of plug-in that can easily make use of MPI with JAX-native sharding, in a way that is relatively stable and not a hack...
A thing we might be able to do is to dlopen() an MPI implementation and implement collectives on top of it. If done that way (dlopen()) it's potentially something we could upstream. (I wouldn't want to require MPI at build time.)
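Roughly the idea (an illustrative ctypes sketch of loading MPI at runtime, not how XLA would actually wire it up; it assumes a libmpi.so that the loader can find):

import ctypes

# dlopen() an MPI implementation at runtime instead of linking against it
# at build time -- the point being that jaxlib itself would not need MPI
# headers or libraries when it is built.
libmpi = ctypes.CDLL("libmpi.so", mode=ctypes.RTLD_GLOBAL)

# MPI_Init and MPI_Finalize accept NULL arguments per the MPI standard.
libmpi.MPI_Init(None, None)
print("MPI initialized through a runtime-loaded library")
libmpi.MPI_Finalize()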
@hawkinsp This is great news. In the summer I had a go at implementing an MPI plugin, inserting the MPI calls directly into the existing PJRT CPU client, see here. It works, but is really only at the proof-of-concept stage (all I needed was a global allreduce). The new interface mentioned above should now make it a lot easier to implement the collectives in a pluggable way.
As of jax 0.4.27 released yesterday (and jaxlib 0.4.26) there is now (finally) support for cross-process communication using MPI. It can be used like this:
Download and compile MPIwrapper
git clone https://github.com/eschnett/MPIwrapper.git
cd MPIwrapper
mkdir build
cd build
cmake ../
make
and initialize jax like this:
import os
os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"
import jax
jax.config.update('jax_cpu_collectives_implementation', 'mpi')
jax.distributed.initialize()
# ...
The libmpiwrapper.so can be found in the build folder created above.
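For a quick end-to-end check, something like this should work (a sketch; it assumes the script is started under an MPI launcher such as mpirun and that jax.distributed.initialize() picks up the rank and world size from the MPI environment, otherwise pass coordinator_address, num_processes and process_id explicitly):

import os
os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"

import jax
import jax.numpy as jnp

jax.config.update('jax_cpu_collectives_implementation', 'mpi')
jax.distributed.initialize()

# Every process contributes a vector of ones; psum returns the total number
# of participating devices on every process.
xs = jnp.ones(jax.local_device_count())
total = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print(jax.process_index(), total)

launched with, e.g., mpirun -np 2 python example.py.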
Unless I am mistaken, it is only possible to use the distributed backend (initialised with jax.distributed.initialize) with the GPU and TPU backends. However, I believe that TensorFlow, and thus XLA, should also support the CPU backend. Would it be possible to support it in JAX as well, so that it can be used with pjit & co?