jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[RFE] Add support for distributed CPU-backend mode #11182

Open PhilipVinc opened 2 years ago

PhilipVinc commented 2 years ago

Unless I am mistaken, it is only possible to use the distributed backend (initialised with jax.distributed.initialize) with the GPU and TPU backends.

However, I believe that TensorFlow, and therefore XLA, also supports the CPU backend. Would it be possible to support it in JAX as well, so that it can be used with pjit & co?
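Concretely, the kind of setup I have in mind looks something like this; the coordinator address and process counts are just illustrative, and today this path is only wired up for GPU/TPU:

import jax

# hypothetical multi-process CPU run: one process per host, each with its own process_id
jax.distributed.initialize(
    coordinator_address="127.0.0.1:8476",  # illustrative
    num_processes=2,
    process_id=0,
)
print(jax.devices())        # ideally: the CPU devices of every participating process
print(jax.process_count())  # 2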

hawkinsp commented 2 years ago

The main thing that JAX is missing to make this work is an implementation of the various collective operations in XLA that works across processes.

One possibility is MPI, in which case the third-party package mpi4jax may be of interest.

Another possibility might be to plug something like gloo into XLA to implement its collectives: https://github.com/facebookincubator/gloo. This would probably not be hard to do; currently the collectives implemented in XLA/CPU are naive reference implementations.
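For reference, the mpi4jax route looks roughly like this (adapted from its README-style usage; the exact token handling differs between mpi4jax versions):

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def global_sum(x):
    # mpi4jax inserts its own collective primitives, which XLA knows nothing about
    x_sum, _token = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
    return x_sum

print(global_sum(jnp.ones(3) * comm.Get_rank()))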

PhilipVinc commented 2 years ago

(I am the author of mpi4jax so I know that one pretty well 😄 )

Thanks for the explanation, I see the issue now.

The main reason I'm asking is that pjit, and distributing along the different axes of a mesh with it, is very interesting to me and I would like to play with it in our packages, but MPI limits us to something like data parallelism only. Unfortunately I have many CPU-based users, which is why I need CPU support.

The other reason is that I'd like to support multiple GPUs/CPUs per node (which is exactly what pjit/GlobalArray does), but that would mean supporting MPI primitives within pmap, which... I'm not sure how to do. Right now I have to force users to launch one JAX process per GPU, which is particularly annoying in some HPC setups.
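For reference, the kind of mesh sharding I'd like to play with looks roughly like the following single-host sketch (written with the current jax.sharding API; the XLA flag just fakes four CPU devices for illustration):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # illustration only

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))
x = jax.device_put(np.arange(16.0).reshape(4, 4), NamedSharding(mesh, P("data", "model")))

@jax.jit
def f(x):
    return 2.0 * x  # jit partitions the computation according to x's sharding

print(f(x).sharding)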

hawkinsp commented 2 years ago

Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar. This is a bit different to mpi4jax where you added a separate set of collective ops unknown to XLA on the side.

Right now, XLA:CPU emits calls to functions in a small helper library that implement collectives: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/cpu/cpu_runtime.h;drc=6eeb889576593a803bce51871b11fb2b27f8f2b3;l=174

Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.) That would most likely need a bit of refactoring so multiple collective implementations can be plugged in.

PhilipVinc commented 2 years ago

Yes, in essence you are asking to replace the collectives emitted by XLA internally with calls to mpi or something similar. This is a bit different to mpi4jax where you added a separate set of collective ops unknown to XLA on the side.

Hmm, yes, that would be amazing, but I can imagine that would be quite a bit of work and I'm unsure whether Google is interested in it. Though surely academic groups working with HPC would be interested and would benefit from it.

Someone would need to teach XLA:CPU how to either call different runtime library functions, or to change those runtime library functions to call MPI (etc.) That would most likely need a bit of refactoring so multiple collective implementations can be plugged in

Getting my hands dirty in XLA itself is a bit beyond the amount of time I have available right now, unfortunately. Is there anything I can do to help you in the process, or to convince you that this is a useful path?

[begin off topic] However, something that might make me temporarily, slightly happier would be a way to insert MPI custom calls into pmap-ped functions. Right now we define C functions that respect the XLA calling convention and then specify how to encode those on CPU and GPU, but I have no idea how to support pmap in this context. If you have any pointers I'd be happy to take them.
[end off topic]

PhilipVinc commented 1 year ago

@hawkinsp did anything change on your end about this recently, or are you still not really planning on supporting this?

alelovato commented 1 year ago

@hawkinsp, are there any updates on this end? I would like to start using JAX on our CPU-based cluster.

jon-chuang commented 1 year ago

Hello, I may be interested in taking this on by implementing gloo-based collective ops, replacing the naive implementations.

How does this sound @PhilipVinc @hawkinsp ?

jon-chuang commented 1 year ago

I will start with all-reduce and broadcast, in that order.

hawkinsp commented 1 year ago

@jon-chuang well that sounds excellent, if you wanted to contribute that! I would agree: start with all-reduce, which is by itself enough for data-parallel training.
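To illustrate: a toy single-host data-parallel step in which a pmean (an all-reduce) of the gradients is the only collective needed; shapes and learning rate are arbitrary:

import functools
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def train_step(w, x, y):
    grads = jax.grad(loss)(w, x, y)
    grads = jax.lax.pmean(grads, "batch")  # the all-reduce under discussion
    return w - 0.1 * grads

n = jax.local_device_count()
w = jnp.zeros((n, 3))    # parameters replicated: one identical copy per device
x = jnp.ones((n, 8, 3))  # each device gets its own shard of the batch
y = jnp.ones((n, 8))
w = train_step(w, x, y)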

jon-chuang commented 1 year ago

Here is the MVP target:

  1. Can perform psum on separate processes running locally with the XLA CPU runtime; e2e test (specifically, multiprocess_cpu_test.py, similar to the GPU test). A rough sketch of what this would exercise follows.
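Roughly, each test process would run something like the following (coordinator address and ranks are illustrative):

import jax
import jax.numpy as jnp

jax.distributed.initialize(coordinator_address="127.0.0.1:8476",
                           num_processes=2,
                           process_id=0)  # each process passes its own rank

# one local CPU device per process: psum reduces over the global axis
# spanning both processes
out = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(
    jnp.array([float(jax.process_index())]))
print(out)  # expected: the sum of all process indices, e.g. [1.] with 2 processes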
jon-chuang commented 1 year ago

@PhilipVinc could you advise on the degree to which we should be able to perform a collective operation across both CPU and GPU (e.g. GPU+CPU offloading)?

In this case:

  1. CPU<->GPU: MPI (incl. cross-process, same device)
  2. CPU<->CPU: MPI
  3. GPU<->GPU: NCCL

One way we could implement this, e.g. for all-reduce, is to:

  1. do a local all-reduce first (CPU<->GPU + GPU<->GPU, same host),
  2. then use either GPU<->GPU or CPU<->CPU (cross-host).

I think that GPU<->GPU should have better performance via NCCL?

See also: https://github.com/alpa-projects/alpa/issues/694

Note that I don't think even torch.distributed allows for a hybrid cluster?

EDIT: gloo implements local reduction in CPU memory - see e.g. cuda_allreduce_ring, cuda_allreduce_ring_chunked. The latter leverages NCCL for same-host multi-GPU reduce/scatter.

PhilipVinc commented 1 year ago

@jon-chuang thank you for looking into this. It's something that would greatly benefit many people including me...

If I understand your question correctly, you want to know which reduction operations must be implemented. In XLA, as of today, there exist only CPU-based reductions (CPU to CPU) or GPU-based reductions (GPU to GPU). That's because an XLA-compiled executable can only run on one platform.

So you should not worry about CPU-GPU reductions, and can always assume that the devices executing your distributed operation are homogeneous. At least, that assumption has worked very well for mpi4jax so far.

There might be plans to allow for hybrid computations (@hawkinsp will know for sure) but I'd leave that out of scope for a first implementation.

--

I think the only operations you need to implement are CPU-CPU reductions, possibly using MPI. GPU-GPU reductions are already implemented (using NCCL, I think).

jon-chuang commented 1 year ago

Actually, I got a hint that the new PJRT runtime can handle a mixed CPU<->GPU workload. Could you confirm @hawkinsp ?

hawkinsp commented 1 year ago

@jon-chuang there are explorations in that direction, but nothing concrete at this time. It might also be done primarily at a layer above PJRT even if it happens.

I would not look into hybrid computations for an MVP.

PhilipVinc commented 1 year ago

Also, a somewhat related question: how are you going to implement this? Ideally this could be a plugin to XLA (do those even exist?) that depends on MPI, not so far off from the current compile step of mpi4jax.

Forcing users to recompile the full jaxlib on HPC machines with finicky compilers is going to be a recipe for problems. Maybe fine for an MVP, but in the long run it will be hard to switch users to it.

jon-chuang commented 1 year ago

Maybe fine for an MVP, but in the long run it will be hard to switch users to it.

As far as plugins go, correct me if I'm wrong, but XLA already has support for such dynamically loaded runtime libraries, since they want to support user-side custom call lowering/dispatch.

I did consider baking this into JAX/XLA, but the plugin route seems neater, and it can still come bundled with JAX/XLA if deemed the most reasonable default.

hawkinsp commented 11 months ago

I added a secret option jax_cpu_enable_gloo_collectives in https://github.com/google/jax/commit/384e29e30d000f1e9c7d7d4a52eadb0ef8a8141a . This enables cross-process CPU support (needs jaxlib from head, built with xla from head).

Please note that:

a) there is no support for encryption: your data will travel unencrypted over the wire. This may or may not be acceptable depending on the application.
b) the collectives are currently synchronous, so they won't be that fast, yet.
c) collectives are only lightly tested so far.
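For reference, enabling it would look something like this (coordinator details are illustrative; the option name is the one mentioned above):

import jax

jax.config.update("jax_cpu_enable_gloo_collectives", True)
jax.distributed.initialize(coordinator_address="10.0.0.1:8476",  # illustrative
                           num_processes=2,
                           process_id=0)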

hawkinsp commented 11 months ago

I should add: it wouldn't be terribly hard to plug in MPI collectives here as well, if one wanted to do so. @PhilipVinc

(One implements: https://github.com/openxla/xla/blob/main/xla/service/cpu/collectives_interface.h essentially.)

PhilipVinc commented 11 months ago

@inailuig has recently (successfully) experimented with plugging MPI into a CPU device... He can say more than I can.

However, what we would love is a way to write a sort of plug-in that can easily make use of MPI with JAX-native sharding, in a way that is relatively stable and not a hack...

hawkinsp commented 11 months ago

A thing we might be able to do is to dlopen() an MPI implementation and implement collectives on top of it. If done that way (dlopen()) it's potentially something we could upstream. (I wouldn't want to require MPI at build time.)

inailuig commented 11 months ago

@hawkinsp This is great news. In the summer I had a go at implementing an MPI plugin, inserting the MPI calls directly into the existing PJRT CPU client; see here. It works, but is really only at the proof-of-concept stage (all I needed was a global allreduce). The new interface mentioned above should now make it a lot easier to implement the collectives in a pluggable way.

inailuig commented 6 months ago

As of JAX 0.4.27, released yesterday (together with jaxlib 0.4.26), there is now (finally) support for cross-process communication using MPI. It can be used like this:

Download and compile MPIwrapper:

git clone https://github.com/eschnett/MPIwrapper.git
cd MPIwrapper
mkdir build
cd build
cmake ../
make

and initialize JAX like this:

import os
os.environ['MPITRAMPOLINE_LIB'] = "/path/to/libmpiwrapper.so"

import jax
jax.config.update('jax_cpu_collectives_implementation', 'mpi')
jax.distributed.initialize()

# ...

The libmpiwrapper.so can be found in the build folder created above.
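Once initialization succeeds, ordinary multi-process JAX code should then work on CPU; for example, a minimal cross-process reduction (assuming one local CPU device per process):

import jax
import jax.numpy as jnp

# after the initialization above: psum over the axis spanning all processes
out = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(jnp.ones(1))
print(jax.process_index(), out)  # every process should print the total number of processes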