google-research / dinosaur


Running Dinosaur across multiple GPU devices #45

Open · opened 2 months ago by sit23

sit23 commented 2 months ago

I've been testing out dinosaur for use in my research group here at the University of Exeter, and I'm attempting to see how dinosaur scales across multiple GPU devices. I've been running a version of your Held-Suarez test case on an AWS instance with 4 T4 GPUs. However, I haven't been able to get the test to make use of more than one of the GPUs: when I monitor usage across the GPUs I see GPU 0 at 100% usage and barely anything on the others:

[Screenshot 2024-07-26: GPU utilisation, with GPU 0 at 100% and the other three nearly idle]

I see that there are test functions with the sharding specified across multiple devices, but I can't see that in the main body of Dinosaur. I'm new to JAX, so perhaps this is my lack of experience showing here; I'm much more used to plain Python and Fortran (I did a lot of development of the Isca modelling framework, https://github.com/ExeClim/Isca). Any help would be gratefully received. Thanks for releasing dinosaur into the wild - it's great to be able to use it.

shoyer commented 1 month ago

Hi @sit23, good to hear from you! (And sorry for the delay responding, I just saw this.)

I'll see if we can work up a demo of sharding Dinosaur across multiple devices. But in brief, you need to explicitly set a JAX parallel mesh on the coordinates object, something like:

import jax
from jax.experimental import mesh_utils

import dinosaur

# Example resolution; substitute your own values.
max_wavenumber = 85
layers = 24

# A (2, 2, 1) mesh over 4 devices: 2-way sharding along 'z' (vertical levels),
# 2-way along 'x', no sharding along 'y' ('x' and 'y' are the two horizontal
# dimensions).
devices = mesh_utils.create_device_mesh((2, 2, 1))
mesh = jax.sharding.Mesh(devices, ['z', 'x', 'y'])

coords = dinosaur.coordinate_systems.CoordinateSystem(
    horizontal=dinosaur.spherical_harmonic.Grid.with_wavenumbers(
        longitude_wavenumbers=max_wavenumber + 1,
        spherical_harmonics_impl=dinosaur.spherical_harmonic.RealSphericalHarmonicsWithZeroImag,
    ),
    vertical=dinosaur.sigma_coordinates.SigmaCoordinates.equidistant(layers),
    spmd_mesh=mesh,
)
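Once the mesh is set on the coordinate system, the model state should end up distributed across all four devices. A quick sanity check (just a sketch; `state.vorticity` assumes a primitive-equations state object, but any array produced by the model will do) is to inspect the sharding of one of the prognostic arrays:

import jax

# Every device in the mesh should hold a shard of the array.
print(state.vorticity.sharding)
for shard in state.vorticity.addressable_shards:
  print(shard.device, shard.data.shape)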
sit23 commented 1 month ago

Thanks @shoyer - this is great. I've been trying your suggestion with the Held-Suarez test case, and I can see that the memory is now being allocated more evenly across the GPUs, which is great, but I've run into a few problems that I've been working through.

One of them is that the stack method applied to the velocity components here https://github.com/google-research/dinosaur/blob/main/dinosaur/held_suarez.py#L122 leads to trouble with the vertical padding function in spherical_harmonic, which asserts that the array is 3D: https://github.com/google-research/dinosaur/blob/main/dinosaur/spherical_harmonic.py#L671 - the stacked velocity components are 4D, so the assertion fails. I think this doesn't matter when you don't have a device mesh, because the assertion is only reached if the mesh is specified. My first workaround was to make the padding 4D when the input is 4D, and this works, but it leads to further errors deeper down in _transform_einsum, which doesn't handle 4D data: https://github.com/google-research/dinosaur/blob/main/dinosaur/spherical_harmonic.py#L373. The workaround I have so far is to compute the velocity tendencies in held_suarez for each velocity component separately, but this is a bit slow. I tried modifying the _transform_einsum routines to cope with 4D input, but that seemed like a more significant undertaking.
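For reference, the per-component workaround looks roughly like this (a sketch with illustrative names, not the actual held_suarez.py code): transform u and v separately, so each array stays 3D and the vertical padding assertion is never hit:

def velocity_tendencies_per_component(u_nodal, v_nodal, kv, coords):
  # kv is the Rayleigh friction coefficient on nodal points (illustrative).
  du_nodal = -kv * u_nodal
  dv_nodal = -kv * v_nodal
  # Transform each 3D array to the modal basis independently, instead of
  # stacking them into a single 4D array first.
  du_modal = coords.horizontal.to_modal(du_nodal)
  dv_modal = coords.horizontal.to_modal(dv_nodal)
  return du_modal, dv_modal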

Now that I'm computing the velocity tendencies separately I've got the integration to run, but I've got a similar problem with 4D arrays appearing in dinosaur.spherical_harmonic.vor_div_to_uv_nodal when I'm converting the output to xarray. I don't think this will be too difficult to get past. Hopefully I'll have a working version of this soon - I'll report back.
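One possible way around that (a sketch, assuming the saved trajectory has a leading time dimension and that vor_div_to_uv_nodal takes the horizontal grid plus modal vorticity and divergence) is to convert one snapshot at a time, so the transforms only ever see 3D arrays:

import jax.numpy as jnp

# Convert the trajectory snapshot-by-snapshot so vor_div_to_uv_nodal only
# ever receives 3D (level, ...) arrays rather than 4D (time, ...) ones.
u_slices, v_slices = [], []
for t in range(vorticity_modal.shape[0]):
  u_t, v_t = dinosaur.spherical_harmonic.vor_div_to_uv_nodal(
      coords.horizontal, vorticity_modal[t], divergence_modal[t])
  u_slices.append(u_t)
  v_slices.append(v_t)
u_nodal = jnp.stack(u_slices)
v_nodal = jnp.stack(v_slices)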

shoyer commented 1 month ago

> One of them is that the stack method applied to the velocity components here https://github.com/google-research/dinosaur/blob/main/dinosaur/held_suarez.py#L122 leads to trouble with the vertical padding function in spherical_harmonic

So as you are finding out, it seems that we have not tested solving Held-Suarez in parallel yet :)

I think there is a simpler fix, using a tuple of 3D arrays instead of a single 4D array: https://github.com/google-research/dinosaur/pull/48
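Roughly, the idea is this (an illustration only, not the actual diff in that PR): keep u and v as separate pytree leaves and map the transform over them, so no 4D array is ever created:

import jax

# Keep the velocity components as a tuple of 3D arrays rather than
# jnp.stack-ing them into one 4D array, then map the nodal -> modal
# transform over the tuple.
nodal_velocity = (u_nodal, v_nodal)
modal_velocity = jax.tree_util.tree_map(coords.horizontal.to_modal, nodal_velocity)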