Closed ChrisAGBlake closed 9 months ago
As near as I can tell, there is not presently a mechanism for dividing GraphCast across GPUs. I have made some attempts via mpirun and by setting OMP thread counts, but I have not been able to achieve anything other than a serial run. I'll follow this issue along with you and see what the folks at Google have to say.
I also encountered the same problem, and I'll follow this issue along with you.
In the paper they describe being able to generate a 10 day forecast on a single TPU v4 device. My mistake was assuming that a v3 device would also work. The v4 has 32GB of memory available per chip, whereas the v3 has only 16GB. I didn't realise this discrepancy initially and assumed that the model would need to be distributed across multiple devices to run. I was able to get it running on NVIDIA GPUs that have >= 24GB memory (I don't have access to v4 TPUs).
The xarray_jax.pmap function can be used for distributing across multiple devices, but only at the batch level (each device runs a full copy of the model on its own slice of the batch), not at the model level.
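To illustrate what batch-level distribution means here, below is a minimal sketch using plain `jax.pmap` (the mechanism `xarray_jax.pmap` wraps for xarray inputs). The `forward` function is a hypothetical stand-in for a model's forward pass, not GraphCast's actual API: each device gets a full copy of the parameters and a shard of the batch, so per-device memory for the model itself is not reduced.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model forward pass. GraphCast's real forward
# function takes xarray Datasets via xarray_jax; this is just for illustration.
def forward(params, inputs):
    return inputs * params["scale"]

n_dev = jax.local_device_count()

# Replicate the full parameter set onto every device -- this is why pmap
# helps with batch throughput but not with fitting a too-large model.
params = {"scale": jnp.float32(2.0)}
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# A batch of 8 samples; the leading axis must be divisible by the device
# count so each device receives one shard.
batch = jnp.arange(8.0).reshape(8, 1)
sharded = batch.reshape(n_dev, 8 // n_dev, 1)

# pmap maps over the leading (device) axis of every argument.
out = jax.pmap(forward)(replicated_params, sharded)
print(out.shape)  # (n_dev, 8 // n_dev, 1)
```

Note that on a TPU v3-8 this spreads batch members over the 8 cores, but each core still needs to hold the whole model plus one sample's activations, which is exactly the 16GB-per-chip limit discussed above.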
I have the low-resolution model running locally in inference on a GPU (RTX 4090) and can also run the high-resolution model (37 pressure levels) for a couple of timesteps before running out of memory. Does anyone have any advice on parallelising across multiple GPUs, or on using a TPU v3-8 instance in GCP and utilising all TPU cores? I see there is the xarray_jax.pmap function which I assume can be used for this, but I'm not sure how to use it properly.