ecmwf-lab / ai-models-graphcast

Jax OOM error when running graphcast #18

Open cguiang opened 2 months ago

cguiang commented 2 months ago

Hi all,

I'm getting the following error when running the 0.25-degree GraphCast operational model (full traceback omitted for brevity):

  File "/home/user/.local/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/home/user/.local/lib/python3.10/site-packages/ai_models/__main__.py", line 358, in main
    _main(sys.argv[1:])
  File "/home/user/.local/lib/python3.10/site-packages/ai_models/__main__.py", line 306, in _main
    run(vars(args), unknownargs)
  File "/home/user/.local/lib/python3.10/site-packages/ai_models/__main__.py", line 331, in run
    model.run()
  File "/home/user/.local/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 240, in run
    output = self.model(
  File "/home/user/.local/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 114, in <lambda>
    return lambda **kw: fn(**kw)[0]
  File "/home/user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
...
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 2.99GiB (3206250496B) on device ordinal 0
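For reference, a standalone allocation of roughly the same size can help distinguish a single oversized request from cumulative model memory. A minimal sketch of such a check (the array size below is chosen only to mirror the 2.99 GiB request and is not part of ai-models):

import jax
import jax.numpy as jnp

# Sanity check, run in a fresh Python process: can JAX allocate one ~3 GiB
# buffer on the GPU? 805,306,368 float32 elements * 4 bytes = 3 GiB.
print(jax.devices())
x = jnp.zeros((805_306_368,), dtype=jnp.float32)
x.block_until_ready()
print("allocated", x.nbytes / 2**30, "GiB on", list(x.devices()))

If a standalone allocation of that size succeeds, the failure above is more likely about the cumulative memory the model needs than about any single buffer.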

I've already tried the following (see the note after this list on how these are set):

  1. Disabling preallocation by setting XLA_PYTHON_CLIENT_PREALLOCATE=false
  2. Setting XLA_PYTHON_CLIENT_MEM_FRACTION to .20, .50 (and other values < .75) with XLA_PYTHON_CLIENT_ALLOCATOR unset
  3. Setting XLA_PYTHON_CLIENT_ALLOCATOR=platform
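Worth double-checking (sketched below under the assumption that the variables might not be reaching the process that initializes JAX): these XLA options only take effect if they are present in the environment before the JAX GPU backend is created; the safest pattern is to export them in the same shell session that launches ai-models, or to set them via os.environ before anything imports jax:

# Illustrative wrapper only, not part of ai-models: set the variables before
# anything imports jax or touches a device; changing them after the backend
# has been created has no effect.
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Alternatively, cap the preallocation fraction instead of disabling it:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"

import jax  # imported only after the environment is configured

print(jax.devices())

The shell equivalent is exporting the variables in the same session, immediately before invoking ai-models.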

System info (only the first of our four GPUs is shown):

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:1B.0 Off |                    0 |
| N/A   36C    P0              25W /  70W |      2MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

Python packages:

$ pip list installed | egrep 'graphcast|jax|cudnn'
ai-models-graphcast          0.0.7
graphcast                    0.1
jax                          0.4.25
jaxlib                       0.4.25+cuda12.cudnn89
nvidia-cudnn-cu12            8.9.7.29

I run "watch -n 2 nvidia-smi" during execution, and the output does indicate that sufficient memory is available on the GPU, (which is not the display GPU for what it's worth.) Any ideas on what I'm doing incorrectly?

Thanks in advance!

cguiang commented 2 months ago

I was able to get a more detailed error after running on a node equipped with a larger-memory GPU. It does look like the model is too big to run on our current instance.

Consider this closed for now.