google-deepmind / graphcast


RESOURCE_EXHAUSTED #73

Closed ncubukcu closed 1 month ago

ncubukcu commented 1 month ago

Hi, I run GraphCast on an AWS instance with 4 GPUs, 48 vCPUs, 192 GB of CPU memory, and 64 GB of GPU memory, which seems like plenty. Unfortunately I get the error below. In fact, regardless of how much GPU memory the instance has, I always get exactly the same issue (ran out of memory trying to allocate 5.96GiB). The AWS instance has the following GPU information: NVIDIA-SMI 535.129.03, Driver Version: 535.129.03, CUDA Version: 12.2.

It seems to me that the code does not partition the work across the GPUs while it is running. The CUDA and JAX installations are:

pip install nvidia-cuda-cupti-cu12==12.2.131
pip install nvidia-cuda-nvcc-cu12==12.2.140
pip install nvidia-cuda-nvrtc-cu12==12.2.140
pip install nvidia-cuda-runtime-cu12==12.2.140
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
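For context, a minimal sanity check (not part of the original run) to confirm that the CUDA build of JAX actually sees all four GPUs before starting the model:

```python
# Minimal check that the CUDA-enabled jaxlib detects every GPU on the instance.
import jax

print(jax.devices())             # should list four CUDA devices on this instance
print(jax.local_device_count())  # expected: 4
```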

Any idea why this is happening?
Thanks

Last part of the run:

2024-05-14 23:55:39,233 INFO Loading params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz: 0.5 second.
2024-05-14 23:55:39,233 INFO Building model: 0.5 second.
2024-05-14 23:55:45,442 INFO Creating forcing variables: 6 seconds.
2024-05-14 23:55:48,562 INFO Converting GRIB to xarray: 3 seconds.
2024-05-14 23:55:49,691 INFO Reindexing: 1 second.
2024-05-14 23:55:49,699 INFO Creating training data: 10 seconds.
2024-05-14 23:55:51,585 INFO Extracting input targets: 1 second.
2024-05-14 23:55:51,585 INFO Creating input data (total): 12 seconds.
2024-05-14 23:55:55,280 INFO Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
2024-05-14 23:55:55,282 INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
/opt/conda/lib/python3.10/site-packages/graphcast/autoregressive.py:202: FutureWarning: The return type of Dataset.dims will be changed to return a set of dimension names in future, in order to be more consistent with DataArray.dims. To access a mapping from dimension names to lengths, please use Dataset.sizes.
  scan_length = targets_template.dims['time']
/opt/conda/lib/python3.10/site-packages/graphcast/autoregressive.py:115: FutureWarning: The return type of Dataset.dims will be changed to return a set of dimension names in future, in order to be more consistent with DataArray.dims. To access a mapping from dimension names to lengths, please use Dataset.sizes.
  num_inputs = inputs.dims['time']
2024-05-14 23:57:26.945643: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.96GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-05-14 23:57:26,980 INFO Doing full rollout prediction in JAX: 1 minute 35 seconds.
2024-05-14 23:57:26,980 INFO Total time: 1 minute 52 seconds.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/ai_models/main.py", line 336, in main
    _main(sys.argv[1:])
  File "/opt/conda/lib/python3.10/site-packages/ai_models/main.py", line 282, in _main
    run(vars(args), unknownargs)
  File "/opt/conda/lib/python3.10/site-packages/ai_models/main.py", line 309, in run
    model.run()
  File "/opt/conda/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 240, in run
    output = self.model(
  File "/opt/conda/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 114, in <lambda>
    return lambda kw: fn(kw)[0]
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6395723776 bytes.
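One knob that is sometimes relevant to this kind of error is JAX's GPU memory preallocation, controlled through the XLA_PYTHON_CLIENT_* environment variables. The sketch below is a diagnostic suggestion, not a confirmed fix for this issue: it relaxes preallocation, which can help with fragmentation, but it cannot help if a single device genuinely has less memory than the model needs.

```python
# Sketch only: relax JAX's default GPU memory preallocation. These variables
# must be set before JAX initializes its GPU backend.
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of up front
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".90"   # or cap the preallocated fraction

import jax  # import after setting the variables so the backend picks them up

print(jax.devices())
```

When running through the ai-models CLI rather than a Python script, the same variables can be exported in the shell before invoking the command.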

alvarosg commented 1 month ago

It seems to me that the code does not partition jobs on each GPU while it is running. CUDA and jax installations are

I don't think this partitioning is something that would happen by default: the code would have to tell JAX which partitioning strategy to use, and that is not how it is implemented. I would recommend running the high-resolution model on hardware with at least 32 GB of memory per device.

Not sure if there would be a way to automate that in this case, but I think that would be a better question for the JAX/XLA GitHub repo.
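For what it's worth, a rough sketch of what "telling JAX which strategy to use for partitioning" looks like with the jax.sharding API; this is illustrative only, not how graphcast is implemented, and the array shapes are made up:

```python
# Illustrative sharding sketch, not GraphCast code: split the leading axis of
# an array across all visible GPUs and run a jitted computation on it.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())            # e.g. the four GPUs on the instance
mesh = Mesh(devices, axis_names=("batch",))

# Hypothetical batched input; the leading axis is split across devices.
x = jnp.zeros((devices.size, 721, 1440))
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("batch", None, None)))

# jit keeps the computation sharded the same way, so each GPU holds 1/4 of x.
y = jax.jit(lambda v: v * 2.0)(x)
print(y.sharding)
```

A single GraphCast forecast is not a batched workload like this, which is why sharding the model itself, rather than just the data, would require deeper changes.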