google-deepmind / graphcast


Inference on multiple TPU cores / GPUs #33

Closed ChrisAGBlake closed 9 months ago

ChrisAGBlake commented 10 months ago

I have the low-resolution model running locally in inference on a GPU (RTX 4090), and I can also run the high-resolution model (37 pressure levels) for a couple of timesteps before running out of memory. Does anyone have advice on parallelising across multiple GPUs, or on using a TPU v3-8 instance in GCP and utilising all TPU cores? I see there is the xarray_jax.pmap function which I assume can be used for this, but I'm not sure how to use it properly.

Dadoof commented 10 months ago

As near as I can tell, there is not presently a mechanism for dividing GraphCast across GPUs. I have made some attempts via mpirun and by setting OMP thread counts, but I have personally not been able to achieve anything other than a 'serial' run. I'll follow this issue along with you and see what the folks at Google have to say.

oubahe commented 9 months ago

I ran into the same problem; I'll follow this issue along with you.

ChrisAGBlake commented 9 months ago

In the paper they describe generating a 10-day forecast on a single TPU v4 device. My mistake was assuming that a v3 device would also work: the v4 has 32GB of memory per chip, whereas the v3 has only 16GB. I didn't realise this discrepancy initially and assumed the model would need to be distributed across multiple devices to run. I was able to get it running on NVIDIA GPUs with >= 24GB of memory (I don't have access to v4 TPUs).

The xarray_jax.pmap function can be used for distributing across multiple devices, but only at the batch level, not at the model level, so it doesn't reduce the memory needed per device for a single forecast.
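To illustrate what batch-level parallelism means here, below is a minimal sketch using plain jax.pmap on a toy function: the model function is replicated onto every local device and the leading batch axis is split across them. xarray_jax.pmap plays the analogous role for xarray-typed inputs (check its docstring in xarray_jax.py for the exact signature); the toy `forward` function is only a stand-in and is not GraphCast's actual forward pass.

```python
# Minimal sketch of batch-level data parallelism with jax.pmap.
# Assumption: `forward` stands in for a single-sample model call; the real
# GraphCast run_forward (wrapped via xarray_jax.pmap) would take its place.
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def forward(x):
    # Toy per-sample computation standing in for the model.
    return x * 2.0

# Batch of inputs: the leading axis must equal the number of local devices.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# pmap replicates `forward` onto each device and splits the batch across them.
# Note: every device still holds a full copy of the model, so this does not
# lower the per-device memory needed to produce one forecast.
parallel_forward = jax.pmap(forward)
out = parallel_forward(batch)
print(out.shape)  # (n_devices, 4)
```

In other words, pmap helps if you want to run several forecasts at once (one per device), but fitting a single high-resolution forecast still requires a device with enough memory on its own, such as a TPU v4 chip or a GPU with >= 24GB.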