eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Multi-host distribution #3

Open EiffL opened 2 years ago

EiffL commented 2 years ago

In fantastic news: after over a year of waiting and checking every few months whether it was working yet, it finally looks possible to instantiate a distributed XLA runtime in JAX, which means... native access to NCCL collectives and composable parallelisation with pmap and xmap!!!

Demo of how to allocate 16 GPUs across 4 nodes on Perlmutter here: https://github.com/EiffL/jax-gpu-cluster-demo
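For reference, here is a minimal sketch of what a multi-host program looks like with current JAX. It assumes the job is launched as one process per node under SLURM, where recent JAX versions can auto-detect the coordinator when `jax.distributed.initialize()` is called without arguments; the `SLURM_NTASKS` check is my own guard so that the same script also runs on a single machine. It is not taken from the demo repository above. The `pmap` + `psum` call then reduces one value per device across every GPU on every node, which NCCL handles under the hood.

```python
import os

import jax
import jax.numpy as jnp


def maybe_init_distributed():
    """Start the multi-host runtime only when launched as several SLURM tasks.

    Assumption: jax.distributed.initialize() auto-detects the coordinator from
    the SLURM environment. It must run before any other JAX computation.
    """
    if int(os.environ.get("SLURM_NTASKS", "1")) > 1:
        jax.distributed.initialize()


def global_psum_demo():
    """Sum one value per device across *all* devices on all hosts with pmap."""
    n_local = jax.local_device_count()
    # Distinct value per device (its global rank, assuming equal devices per host)
    x = jnp.arange(n_local, dtype=jnp.float32) + n_local * jax.process_index()
    # psum over the pmapped axis spans every device in the distributed runtime
    return jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)


maybe_init_distributed()
summed = global_psum_demo()
```

Each local entry of `summed` should equal 0 + 1 + ... + (N - 1), where N is `jax.device_count()` over all hosts.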

I'll be testing these things out and documenting my findings in this issue. It may not be directly useful at first, but at some point down the line we want to be able to run very large sims easily.

EiffL commented 2 years ago

ok.... well.... either it's black magic, or there is something I don't understand, but in any case my mind is blown....

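import numpy as np
import jax
from jax.experimental import maps
from jax.experimental.pjit import pjit, PartitionSpec
# pm, cosmo, dyn_conf, stat_conf and init_cond are defined earlier (not shown)

mesh_shape = (2,) # On 2 GPUs
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = maps.Mesh(devices, ('x',))  # axis names must be a tuple

# Split the (128, 128, 128) initial field along its first axis,
# and the (n_particles, 3) output along the particle axis
parallel_za = pjit(
  lambda x: pm.generate_za([128,128,128], x, cosmo, dyn_conf, stat_conf).dm,
  in_axis_resources=PartitionSpec('x', None, None),
  out_axis_resources=PartitionSpec('x', None))

with maps.mesh(mesh.devices, mesh.axis_names):
  data = parallel_za(init_cond)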

appears to be all it takes to distribute across multiple devices... and yet it returns the correct result...

I'm very puzzled by this... To perform this operation it has to do a bunch of things, like an FFT over the distributed init_cond field (which is split in 2 along the first dimension), which I can imagine. But then it needs to compute the displacement of 2 batches of particles, and I'm really not sure how particles in process 2 get to know about the density stored by process 1. Unless... it internally "undistributes" the data at some point, which would defeat the purpose... or it is smart enough to devise a communication strategy to retrieve the needed data...
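One way to see what XLA actually does here is to compile a sharded function and read the resulting HLO. The sketch below uses the current `jax.sharding` API rather than the `maps`/`pjit` API above: it slab-decomposes a 3D field along its first axis, runs `jnp.fft.fftn` under `jax.jit`, and returns the compiled HLO text. My understanding is that the GSPMD partitioner inserts whatever collectives it needs (all-gather, all-to-all, collective-permute); which ones show up is a choice of the compiler, not something this sketch controls. With a single device there is nothing to communicate, so the HLO will contain no collectives.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharded_fft_demo(n=16):
    # 1D mesh over every visible device; n must be divisible by the device count
    mesh = Mesh(np.asarray(jax.devices()), ("x",))
    # Slab decomposition: split the first axis of the field across devices
    slabs = NamedSharding(mesh, P("x", None, None))

    field = jax.random.normal(jax.random.PRNGKey(0), (n, n, n))
    field = jax.device_put(field, slabs)

    # We only state where the input lives and where the output should live;
    # the partitioner decides what communication is needed in between.
    fft3d = jax.jit(jnp.fft.fftn, out_shardings=slabs)
    result = fft3d(field)

    # Search this text for all-gather / all-to-all / collective-permute ops
    hlo = fft3d.lower(field).compile().as_text()
    return field, result, hlo


field, result, hlo = sharded_fft_demo()
```

Comparing `result` against an unsharded FFT confirms correctness; grepping `hlo` on a multi-GPU run answers whether the data was gathered onto each device (which would indeed defeat the purpose) or exchanged more cleverly.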

EiffL commented 1 year ago

@eelregit Making progress on this ^^ jaxdecomp is now able to do forward and backward FFTs https://github.com/DifferentiableUniverseInitiative/jaxDecomp

Still have to add a few things, but really not far away from being usable as part of pmwd

EiffL commented 1 year ago

Ok, I've added halo exchange and cleaned up the interface. I also added gradients of these operations. You can also select which communication backend you want to use: MPI, NCCL, or NVSHMEM. As far as I can tell, this should be strictly superior to the cuFFTMp library, although now that I know how to write these bindings, it would be even easier to use cuFFTMp.
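To make "halo exchange" concrete, here is a single-device reference of what a periodic exchange computes for a 1D slab decomposition. This is a hypothetical helper for testing, not the jaxdecomp implementation: jaxdecomp's `halo_exchange` works on pencils, handles all three axes, and fills an already-padded array in place, whereas this sketch only treats the first axis and builds the padded slabs by concatenation. Because it is written with plain `jnp` ops, `jax.vjp` gives its gradient for free (the adjoint sums each halo contribution back onto the slab that owns those planes).

```python
import jax.numpy as jnp


def split_slabs(field, n_slabs):
    """Emulate a slab decomposition on one device: (n_slabs, nx // n_slabs, ny, nz)."""
    return jnp.stack(jnp.split(field, n_slabs, axis=0))


def periodic_halo_reference(slabs, halo):
    """Pad every slab with `halo` planes copied from its periodic neighbours.

    Requires halo <= slab thickness. Returns (n_slabs, nx_local + 2 * halo, ny, nz).
    """
    # Lower halo of slab i = last `halo` planes of slab i - 1 (wrapping around)
    lower = jnp.roll(slabs[:, -halo:], shift=1, axis=0)
    # Upper halo of slab i = first `halo` planes of slab i + 1 (wrapping around)
    upper = jnp.roll(slabs[:, :halo], shift=-1, axis=0)
    return jnp.concatenate([lower, slabs, upper], axis=1)
```

On a distributed array, each of the two `jnp.roll` calls corresponds to one neighbour-to-neighbour send/receive per process.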

Here is how you do a 3D FFT distributed on many GPUs with the current version of the API:

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import jax
import jax.numpy as jnp
import jaxdecomp

# Initialise the library, and optionally select communication backends (defaults to NCCL)
jaxdecomp.init()
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Set up a 2 x 4 processor mesh (pdims[0] * pdims[1] must equal the number of MPI processes, "size")
pdims = [2, 4]
global_shape = [1024, 1024, 1024]
assert pdims[0] * pdims[1] == size

# Initialize this process's local block of an array with the expected global size
array = jax.random.normal(shape=[global_shape[0] // pdims[1],
                                 global_shape[1] // pdims[0],
                                 global_shape[2]],
                          key=jax.random.PRNGKey(rank)).astype('complex64')

# Forward FFT; note that the output is transposed
karray = jaxdecomp.pfft3d(array,
                          global_shape=global_shape, pdims=pdims)

# Inverse FFT
recarray = jaxdecomp.ipfft3d(karray,
                             global_shape=global_shape, pdims=pdims)

# Add halo regions of 32 cells on each side of every axis
padded_array = jnp.pad(array, [(32, 32), (32, 32), (32, 32)])
# Perform a halo exchange to fill the padding with neighbours' data (periodic boundaries)
padded_array = jaxdecomp.halo_exchange(padded_array,
                                       halo_extents=(32, 32, 32),
                                       halo_periods=(True, True, True),
                                       pdims=pdims,
                                       global_shape=global_shape)

Compiling is unfortunately not 100% trivial, because it depends a lot on the local environment of the cluster, so I haven't managed to fully automate it yet...
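The local block shape in the snippet follows a simple rule: the first global axis is split over `pdims[1]` and the second over `pdims[0]`, while the third stays whole. The hypothetical helpers below just mirror that convention (I have not checked whether other jaxdecomp versions use a different one) and compute the per-process memory, which is handy when picking `pdims` for a given GPU.

```python
def pencil_local_shape(global_shape, pdims):
    """Local block shape as in the snippet: axis 0 split over pdims[1], axis 1 over pdims[0]."""
    nx, ny, nz = global_shape
    px, py = pdims[1], pdims[0]
    if nx % px or ny % py:
        raise ValueError("global_shape must be divisible by the processor grid")
    return (nx // px, ny // py, nz)


def bytes_per_process(global_shape, pdims, itemsize=8):
    """Memory of the local block; complex64 uses 8 bytes per element."""
    lx, ly, lz = pencil_local_shape(global_shape, pdims)
    return lx * ly * lz * itemsize
```

For the 1024^3 example with `pdims = [2, 4]`, this gives 8 processes, each holding a (256, 512, 1024) complex64 block of exactly 1 GiB, before any FFT workspace.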

EiffL commented 1 year ago

One 1024^3 FFT on 4 V100 GPUs... 0.5 ms :rofl:

EiffL commented 1 year ago

Annnnd 50ms for a 2048^3 FFT on 16 V100 GPUs on 2 nodes... :exploding_head: (also tagging @modichirag )

eelregit commented 1 year ago

Thanks! This is a lot of progress.

> One 1024^3 FFT on 4 V100 GPUs... 0.5 ms :rofl:

Wow. But doesn't 0.5 ms sound too fast, like faster than the memory bandwidth on 1 GPU?
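A back-of-the-envelope check makes the suspicion concrete. Assuming roughly 900 GB/s of HBM2 bandwidth per V100 (the SXM2 figure) and counting only a single read plus a single write of each GPU's share of a complex64 array, the floor is already close to 5 ms; a real distributed 3D FFT makes several passes and adds communication on top. The function below is a hypothetical idealised model, not a measurement.

```python
def fft_bandwidth_floor_ms(n, n_gpus, bandwidth_gb_s, itemsize=8, passes=1):
    """Lower bound on the time for an n^3 FFT split evenly over n_gpus.

    Each pass reads and writes the local array once from device memory;
    communication and the extra passes of a real 3D FFT are ignored.
    """
    local_bytes = n**3 * itemsize / n_gpus
    traffic_bytes = 2 * local_bytes * passes  # one read + one write per pass
    return traffic_bytes / (bandwidth_gb_s * 1e9) * 1e3


floor_ms = fft_bandwidth_floor_ms(1024, 4, 900)
```

Here `floor_ms` comes out at about 4.8 ms, roughly ten times the reported 0.5 ms.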

> Annnnd 50ms for a 2048^3 FFT on 16 V100 GPUs on 2 nodes

Should we expect weak scaling here? In that case the time would be (a bit more than) 2x the 1024^3 timing.
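The "a bit more than 2x" follows from the per-GPU workload. 2048^3 on 16 GPUs is twice as many elements per GPU as 1024^3 on 4, and an FFT costs about N log N, which adds a factor of log2(2048^3) / log2(1024^3) = 33/30. The sketch below encodes that compute-only model as an assumption; it ignores communication entirely, which is precisely the part that may not scale this way.

```python
import math


def weak_scaling_ratio(n_small, p_small, n_large, p_large):
    """Expected time ratio if cost per GPU scales like (N log2 N) / P, with N = n^3."""
    def cost_per_gpu(n, p):
        total = n**3
        return total * math.log2(total) / p
    return cost_per_gpu(n_large, p_large) / cost_per_gpu(n_small, p_small)


ratio = weak_scaling_ratio(1024, 4, 2048, 16)
```

Under this model `ratio` is 2.2.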

EiffL commented 1 year ago

0.5ms does sound really fast, but the result of the FFT seems to be correct, so.... maybe?

I'm not 100% sure what scaling we should expect: the cost may depend on message size, since the backend might switch between different communication strategies.

Also, interesting note: this is using the NCCL backend; if I use the MPI backend with this setting and hardware I get 6 s. I guess it will be very hardware- and problem-dependent, but that's what's nice about cuDecomp: it includes an autotuning tool that finds the best distribution strategy and backend for a given hardware and problem size (which I haven't interfaced in JAX yet, but it is there at the C++ level).

eelregit commented 1 year ago

Okay, 50ms seems to be comparable to what cuFFTMp showed in https://developer.nvidia.com/blog/multinode-multi-gpu-using-nvidia-cufftmp-ffts-at-scale/

For me one 1024^3 FFT on 1 A100 seems to take more than 30ms. So... are you sure that 0.5ms on 4 V100 is not underestimating?

EiffL commented 1 year ago

Oops ^^' you are right, I didn't include a block_until_ready()...
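For anyone reproducing these numbers: JAX dispatches work asynchronously, so timing a call without waiting on its result only measures how long it takes to enqueue the computation. A minimal timing helper (hypothetical name, standard JAX calls) that excludes compilation and waits for every run looks like this; on a multi-GPU setup you would time the distributed FFT instead of the single-device `jnp.fft.fftn` used here.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, n_repeat=10):
    """Average wall time of fn(*args) in seconds, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_repeat):
        # Wait for the device to finish, not just for dispatch to return
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_repeat


fft3d = jax.jit(jnp.fft.fftn)
x = jnp.ones((64, 64, 64), dtype=jnp.complex64)
elapsed = time_jax(fft3d, x)
```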

New timings on V100s:

This is probably a lot more reasonable

eelregit commented 1 year ago

Still very promising. I wonder if the difference in performance and scaling is mainly from the hardware (NVLink, NVSwitch, etc.)?

wendazhou commented 1 year ago

NVIDIA claims that they can achieve close to perfect weak scaling with cuFFTMp in the 2^30 elements/GPU range (up to about 4k GPUs), but I know that library leans heavily on NVSHMEM communication to achieve optimal overlapping. NVIDIA claims that their cluster can do 2048^3 in ~100 ms on 16 GPUs (albeit A100s), so it might definitely be worth looking into using that library directly, or setting up the correct hardware configuration for cuFFTMp and cuDecomp.

EiffL commented 1 year ago

Yep, it should be trivial to add an op for cuFFTMp in jaxdecomp, as it's already part of the NVIDIA HPC SDK (nvhpc), so no need to compile an external library :-) I'm traveling this week, but if I catch a little quiet time I'll add the option to use cuFFTMp. Unless you want to have a go at it @wendazhou ;-)

One thing I thought about, though: reading the documentation, it looks like NVSHMEM memory needs to be allocated in a particular fashion, different from standard CUDA device memory. That means we can't directly use the input/output buffers allocated by XLA, and the op will need to allocate its own buffers using NVSHMEM, which will roughly double the memory needed to run the FFT.

In my current implementation, I do the transform in-place within the input buffer allocated by XLA. I also need a workspace buffer, also allocated by XLA, whose size is determined by cuDecomp (I think something like twice the size of the input array).
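A rough per-GPU memory budget helps compare the two options. The sketch below assumes the local complex64 array, plus a workspace of about twice its size (the estimate above, not a documented cuDecomp figure), plus optionally one more array-sized copy if the data has to be staged into separately allocated NVSHMEM buffers. The function name and parameters are hypothetical.

```python
def fft_memory_per_gpu_gib(n, n_gpus, itemsize=8, workspace_factor=2.0,
                           extra_buffer=False):
    """Rough per-GPU memory (GiB) for an in-place distributed n^3 FFT.

    workspace_factor: workspace size relative to the local array (assumed ~2x).
    extra_buffer: add one more local-array-sized copy, e.g. NVSHMEM staging.
    """
    local_bytes = n**3 * itemsize / n_gpus
    copies = 1 + workspace_factor + (1 if extra_buffer else 0)
    return local_bytes * copies / 2**30


in_place = fft_memory_per_gpu_gib(2048, 16)
with_staging = fft_memory_per_gpu_gib(2048, 16, extra_buffer=True)
```

For 2048^3 on 16 GPUs this gives 12 GiB in place versus 16 GiB with an extra staging copy, which would leave no headroom on a 16 GB V100.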

wendazhou commented 1 year ago

I'm also travelling for NeurIPS this week, so I probably won't have time to look at it. For the memory, I don't think the input/output buffers need to be allocated using special functions, only the scratch memory itself, which cuFFTMp handles internally (but this does require memory linear in the transform size, see doc).

I think the main work, in addition to plumbing everything together, will be figuring out how to describe the data layout correctly (see doc).

eelregit commented 1 year ago

@hai4john

EiffL commented 6 months ago

@eelregit ... ok, took about a year, but it's now working nicely with the latest version of JAX thanks to the heroic efforts of my collaborator @ASKabalan :-)

We have updated jaxDecomp https://github.com/DifferentiableUniverseInitiative/jaxDecomp to be compatible with the native JAX distribution, and we have a rough prototype of a distributed LPT demo here: https://github.com/DifferentiableUniverseInitiative/JaxPM/issues/19#issuecomment-2103839585

We have been able to scale it to 24 GPUs (the max I can allocate at Flatiron), and to give you an idea, it executes an LPT simulation (no PM iterations) on a 2048^3 mesh in 4.7 s. We haven't carefully profiled the execution yet, and it's not impossible we could do even better, but at least it means we can do reasonable cosmological volumes in a matter of seconds:

[image not preserved]

With @ASKabalan we are probably going to integrate this into JaxPM, as a way to prototype user APIs compatible with distribution (it's not completely trivial: we don't want to bother users with a lot of extra machinery if they are running on a single GPU). But I still have in mind to bring distribution to pmwd, so I wanted to check in with you to see what you think and whether you already have another path to distribution in mind.

eelregit commented 6 months ago

Thanks @EiffL for letting me know. Let's find a time (maybe next week) to chat via email?

EiffL commented 4 months ago

@eelregit here is a minimal demo of LPT implemented using jaxdecomp: https://github.com/DifferentiableUniverseInitiative/jaxDecomp/blob/main/examples/lpt_nbody_demo.py
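For readers who have not opened the demo: the core of first-order LPT (the Zel'dovich approximation) is one forward FFT of the density contrast, a multiplication by i k / k^2, and three inverse FFTs, so that the displacement satisfies delta = -div(Psi). The sketch below does this on a single device with plain `jnp.fft`; it is not the code from the linked demo, which performs the same FFTs with jaxdecomp across GPUs. It assumes a cubic mesh and omits the growth factor and the step of moving particles to q + Psi.

```python
import jax.numpy as jnp


def zeldovich_displacement(delta, box_size):
    """First-order LPT displacement Psi(k) = i k / k^2 * delta(k) on a cubic mesh.

    delta: real (n, n, n) density contrast. Returns Psi with shape (3, n, n, n).
    """
    n = delta.shape[0]
    delta_k = jnp.fft.rfftn(delta)
    kx = 2 * jnp.pi * jnp.fft.fftfreq(n, d=box_size / n)
    kz = 2 * jnp.pi * jnp.fft.rfftfreq(n, d=box_size / n)  # last axis is halved by rfftn
    kvec = jnp.meshgrid(kx, kx, kz, indexing="ij")
    k2 = kvec[0] ** 2 + kvec[1] ** 2 + kvec[2] ** 2
    k2 = k2.at[0, 0, 0].set(1.0)  # avoid 0/0; the k = 0 mode gets zero displacement anyway
    psi = [jnp.fft.irfftn(1j * ki / k2 * delta_k, s=delta.shape) for ki in kvec]
    return jnp.stack(psi)
```

For a single plane wave delta = cos(k0 x), this returns Psi_x = -sin(k0 x) / k0 and zero displacement along the other two axes.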

EiffL commented 4 months ago

@eelregit timings https://flanusse.net/talks/Split2024/#/15/0/1