ecmwf-lab / ai-models


Run out of memory when forecasting much longer lead times #14

Open whu-dyf opened 1 year ago

whu-dyf commented 1 year ago

For ai-models-graphcast, it works fine when I predict only a few time steps. However, it fails with an "out of memory" error when I try to predict over a longer lead time, such as 10 days. I have 188 GB of CPU memory and 24 GB of GPU memory. Is there any solution to avoid this issue? It appears that the memory used by the model is not released after each step completes.

Thanks for your reply in advance!

b8raoult commented 1 year ago

Yes, we are aware of that problem and are looking into it.

I-Dhar commented 1 year ago

Try using these commands before running graphcast. Explanation is at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
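
For reference, the same two settings can also be applied from inside a Python script before JAX initializes its backend. This is only a minimal sketch for people driving GraphCast from their own code; when using the ai-models command line, the exports above are sufficient.

import os

# These must be set before JAX creates its backend, i.e. before "import jax".
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # do not preallocate ~75% of GPU memory up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate and free on demand (slower, but releases memory)

import jax
print(jax.devices())  # the backend is initialized here, with the settings above in effect
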
whu-dyf commented 1 year ago

Try using these commands before running graphcast. Explanation is at https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

I have tried your solution, but it failed with the following message:

2023-10-26 10:49:40,924 INFO Download rate 3.8M/s
2023-10-26 10:49:49,065 INFO Creating forcing variables: 6 seconds.
2023-10-26 10:49:56,199 INFO Converting GRIB to xarray: 7 seconds.
2023-10-26 10:50:01,614 INFO Reindexing: 5 seconds.
2023-10-26 10:50:01,665 INFO Creating training data: 6 minutes 22 seconds.
2023-10-26 10:50:11,951 INFO Extracting input targets: 10 seconds.
2023-10-26 10:50:11,951 INFO Creating input data (total): 6 minutes 32 seconds.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1698288611.955721 7562 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-10-26 10:50:12,075 INFO Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
2023-10-26 10:50:12,075 INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2023-10-26 10:50:36.996835: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

%pad.151 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.369, bf16[] %constant.771), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-10-26 10:50:39.321182: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3.32445162s
Constant folding an instruction is taking > 1s:

%pad.151 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.369, bf16[] %constant.771), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
W0000 00:00:1698288645.880896 7562 hlo_rematerialization.cc:2946] Can't reduce memory use below -17.35GiB (-18632099542 bytes) by rematerialization; only reduced to 52.66GiB (56544015396 bytes), down from 52.66GiB (56544015396 bytes) originally
2023-10-26 10:50:47.788508: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 5.72GiB (6146380800B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 2.06GiB
    constant allocation: 159.29MiB
    maybe_live_out allocation: 35.12GiB
    preallocated temp allocation: 15.75GiB
    preallocated temp fragmentation: 616B (0.00%)
    total allocation: 53.09GiB
Peak buffers:

Buffer 1:
    Size: 8.91GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_process/mesh2grid_gnn/_process_step/mesh2grid_gnn/concatenate[dimension=2]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: bf16[3114720,1,1536]
    ==========================

Buffer 2:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 3:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 4:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 5:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 6:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 7:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 8:
    Size: 2.97GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: custom-call
    Shape: bf16[3114720,512]
    ==========================

Buffer 9:
    Size: 1.98GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/add" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[1038240,1,512]
    ==========================

Buffer 10:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 11:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 12:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 13:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

Buffer 14:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

Buffer 15:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

2023-10-26 10:50:47,794 INFO Doing full rollout prediction in JAX: 35 seconds.
2023-10-26 10:50:47,794 INFO Total time: 7 minutes 9 seconds.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/dyf/anaconda3/envs/graphcast/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/home/dyf/anaconda3/envs/graphcast/lib/python3.10/site-packages/ai_models/__main__.py", line 285, in main
    _main()
  File "/home/dyf/anaconda3/envs/graphcast/lib/python3.10/site-packages/ai_models/__main__.py", line 258, in _main
    model.run()
  File "/home/dyf/anaconda3/envs/graphcast/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 240, in run
    output = self.model(
  File "/home/dyf/anaconda3/envs/graphcast/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 117, in <lambda>
    return lambda **kw: fn(**kw)[0]
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 5.72GiB (6146380800B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 2.06GiB
    constant allocation: 159.29MiB
    maybe_live_out allocation: 35.12GiB
    preallocated temp allocation: 15.75GiB
    preallocated temp fragmentation: 616B (0.00%)
    total allocation: 53.09GiB
Peak buffers:

Buffer 1:
    Size: 8.91GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_process/mesh2grid_gnn/_process_step/mesh2grid_gnn/concatenate[dimension=2]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: bf16[3114720,1,1536]
    ==========================

Buffer 2:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 3:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 4:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 5:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 6:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 7:
    Size: 5.72GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/dynamic_update_slice" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[40,1,37,721,1440]
    ==========================

Buffer 8:
    Size: 2.97GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_1/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: custom-call
    Shape: bf16[3114720,512]
    ==========================

Buffer 9:
    Size: 1.98GiB
    Operator: op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/add" source_file="/home/dyf/anaconda3/envs/graphcast/bin/ai-models" source_line=8
    XLA Label: fusion
    Shape: f32[1038240,1,512]
    ==========================

Buffer 10:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 11:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 12:
    Size: 293.08MiB
    XLA Label: fusion
    Shape: f32[1,2,37,721,1440]
    ==========================

Buffer 13:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

Buffer 14:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

Buffer 15:
    Size: 293.08MiB
    Entry Parameter Subshape: f32[1,2,37,721,1440]
    ==========================

I0000 00:00:1698288647.995927 7562 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

idharssi2020 commented 1 year ago

I'm using a Tesla V100-SXM2-32GB.

whu-dyf commented 1 year ago

I'm using a Tesla V100-SXM2-32GB.

I have rented two RTX 4090 24GB GPUs from a cloud service, but unfortunately they are not functioning properly for this. As a result, I am only able to predict around 3 to 5 steps (I forget the exact number).

Dadoof commented 1 year ago

Good day all, I have run this on AWS, and was only able to do so using their 'p5' instance. Anything less, and I got the memory issues noted above.

chengshenlian commented 11 months ago

Initially, I used a 3090 with 24GB, but the run failed immediately. It was not until I rented an A100 with 80GB that I found everything could operate normally.

Following the example command ai-models --input cds --date 20230110 --time 0000 graphcast, I successfully obtained a file named graphcast.grib.

Additionally, during the run I observed that both system memory and GPU memory usage were around 60 GB (by default, the JAX library directly claims about 3/4 of the GPU memory).

(ai) root@747c17acba6c:/# /opt/conda/envs/ai/bin/ai-models  --input cds --date 20230110 --time 0000 graphcast
2023-11-28 06:50:35,205 INFO Writing results to graphcast.grib.
/opt/conda/envs/ai/lib/python3.10/site-packages/ecmwflibs/__init__.py:81: UserWarning: /lib/x86_64-linux-gnu/libgobject-2.0.so.0: undefined symbol: ffi_type_uint32, version LIBFFI_BASE_7.0
  warnings.warn(str(e))
2023-11-28 06:50:35,531 INFO Model description: 
GraphCast model at 0.25deg resolution, with 13 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and fine-tuned on HRES-fc0 data from
2016 to 2021 and can be causally evaluated on 2022 and later years. This model
does not take `total_precipitation_6hr` as inputs and can make predictions in an
operational setting (i.e., initialised from HRES-fc0).

2023-11-28 06:50:35,531 INFO Model license: 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.

2023-11-28 06:50:35,531 INFO Loading params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz: 0.3 second.
2023-11-28 06:50:35,531 INFO Building model: 0.3 second.
2023-11-28 06:50:35,531 INFO Loading surface fields from CDS
2023-11-28 06:50:35,656 INFO Loading pressure fields from CDS
2023-11-28 06:50:48,418 INFO Creating forcing variables: 12 seconds.
2023-11-28 06:50:53,993 INFO Converting GRIB to xarray: 5 seconds.
2023-11-28 06:50:57,666 INFO Reindexing: 3 seconds.
2023-11-28 06:50:57,706 INFO Creating training data: 22 seconds.
2023-11-28 06:51:04,715 INFO Extracting input targets: 6 seconds.
2023-11-28 06:51:04,715 INFO Creating input data (total): 29 seconds.
2023-11-28 06:51:05,098 INFO Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
2023-11-28 06:51:05,102 INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2023-11-28 06:52:10.480192: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %pad.149 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.365, bf16[] %constant.768), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/opt/conda/envs/ai/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-11-28 06:52:18.177603: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 8.697529096s
Constant folding an instruction is taking > 1s:

  %pad.149 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.365, bf16[] %constant.768), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/opt/conda/envs/ai/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-11-28 06:52:20.556910: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:

  %pad.1 = bf16[1618752,8]{1,0} pad(bf16[1618745,4]{1,0} %constant.374, bf16[] %constant.687), padding=0_7x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_embed/grid2mesh_gnn/sequential/encoder_edges_grid2mesh_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/opt/conda/envs/ai/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-11-28 06:52:22.395348: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3.838535293s
Constant folding an instruction is taking > 2s:

  %pad.1 = bf16[1618752,8]{1,0} pad(bf16[1618745,4]{1,0} %constant.374, bf16[] %constant.687), padding=0_7x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_embed/grid2mesh_gnn/sequential/encoder_edges_grid2mesh_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/opt/conda/envs/ai/bin/ai-models" source_line=8}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-11-28 06:52:37,380 INFO Doing full rollout prediction in JAX: 1 minute 32 seconds.
2023-11-28 06:52:37,380 INFO Converting output xarray to GRIB and saving
2023-11-28 06:54:53,203 INFO Saving output data: 2 minutes 15 seconds.
2023-11-28 06:54:53,276 INFO Total time: 4 minutes 19 seconds.
(ai) root@747c17acba6c:/# ls
graphcast.grib  params  sspaas-fs  sspaas-tmp  stats  test.py  tf-logs
(ai) root@747c17acba6c:/# du -lh graphcast.grib 
6.5G    graphcast.grib


TaoWei8138 commented 11 months ago

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running FusedMatMul node. Name:'MatMul_With_Transpose_FusedMatMulAndScale' Status Message: /onnxruntime_src/onnxruntime/core/framework/bfc_arena.cc:376 void onnxruntime::BFCArena::AllocateRawInternal(size_t, bool, onnxruntime::Stream, bool, onnxruntime::WaitNotificationFn) Failed to allocate memory for requested buffer of size 1851310080

It also happens when running Pangu, but after a few repeated attempts it works fine.

BigShuiTai commented 10 months ago

I have the same issue as above, and I am using a Tesla M40 24GB GPU. Is there any solution?

Trek-on commented 10 months ago

I have the same issue as above, and I am using an NVIDIA 2080 Ti 11GB GPU. Is there any solution? Is it possible to solve this issue by adjusting the batch_size of the model input, and if so, how?

mchantry commented 10 months ago

Hello all, There is no easy solution for this, as it would require a refactor of the code. We will consider it a feature request for the future, but currently do not have any capacity to fulfil it. Thanks, Matthew

Trek-on commented 10 months ago

May I ask what is the minimum graphics card configuration and minimum memory requirement for running the model?

BigShuiTai commented 10 months ago

May I ask what is the minimum graphics card configuration and minimum memory requirement for running the model?

From my experience, the minimum GPU memory requirement for running a model like Pangu-Weather is 12 to 16 GiB, and for CPU-only inference at least 16 GiB of RAM is needed.

Trek-on commented 10 months ago

May I ask what is the minimum graphics card configuration and minimum memory requirement for running the model?

From my experience, the minimum GPU memory requirement for running a model like Pangu-Weather is 12 to 16 GiB, and for CPU-only inference at least 16 GiB of RAM is needed.

How about Graphcast?

BigShuiTai commented 10 months ago

May I ask what is the minimum graphics card configuration and minimum memory requirement for running the model?

From my experience, the minimum GPU memory requirement for running a model like Pangu-Weather is 12 to 16 GiB, and for CPU-only inference at least 16 GiB of RAM is needed.

How about Graphcast?

I'm still testing GraphCast using GFS analysis data as the input. It used at least 18 GiB of my GPU memory but eventually ran out of memory.

Trek-on commented 8 months ago

Hello all, There is no easy solution for this, as it would require a refactor of the code. We will consider it a feature request for the future, but currently do not have any capacity to fulfil it. Thanks, Matthew

May I ask what is the minimum graphics card configuration and minimum memory requirement for running the model? I am still bothered by this issue. Is an A100 necessary?

xionghan7427 commented 6 months ago

I am trying to run the models with fewer steps and then rerun them with the output GRIB file as the new input file, but I get errors; it seems the output GRIB file cannot be used as an input file for these models. Any idea why this is?

Trek-on commented 6 months ago

Maybe your output file only contains NWP data on 13 pressure levels rather than 37 pressure levels; this depends on the params file of your model.
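
To check how many pressure levels an output file actually contains, here is a small sketch using xarray with the cfgrib engine (this assumes cfgrib is installed and that graphcast.grib is the output file in question):

import xarray as xr

# Open only the fields defined on pressure levels and list the levels present.
ds = xr.open_dataset(
    "graphcast.grib",
    engine="cfgrib",
    backend_kwargs={"filter_by_keys": {"typeOfLevel": "isobaricInhPa"}},
)
print(sorted(ds["isobaricInhPa"].values))  # 13 levels in GraphCast output vs 37 in the ERA5/HRES input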


BigShuiTai commented 6 months ago

I have written code to solve this issue myself; you can check my forked repository. Hope this helps!

xionghan7427 commented 6 months ago

I have written code to solve this issue myself; you can check my forked repository. Hope this helps!

I found your forked repository of graphcast, but after I installed it and ran ai-models, the inference process got killed at the first step:

(ai-models) [@localhost ~]$ ai-models --date 20240425 --time 1200 --lead-time 384 --path graphcast/2024-04-25z12:00:00.grib graphcast
2024-04-26 17:32:58,444 INFO Writing results to graphcast/2024-04-25z12:00:00.grib.
2024-04-26 17:32:58,444 INFO Loading surface fields from MARS
2024-04-26 17:32:58,638 INFO Loading pressure fields from MARS
2024-04-26 17:32:59,081 INFO Writing step 0: 0.5 second.
2024-04-26 17:32:59,512 INFO Writing step 0: 0.4 second.
2024-04-26 17:32:59,768 INFO Model description:
GraphCast model at 0.25deg resolution, with 13 pressure levels. This model is trained on ERA5 data from 1979 to 2017, and fine-tuned on HRES-fc0 data from 2016 to 2021 and can be causally evaluated on 2022 and later years. This model does not take total_precipitation_6hr as inputs and can make predictions in an operational setting (i.e., initialised from HRES-fc0).

2024-04-26 17:32:59,768 INFO Model license:
The model weights are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You may obtain a copy of the License at: https://creativecommons.org/licenses/by-nc-sa/4.0/. The weights were trained on ERA5 data, see README for attribution statement.

2024-04-26 17:32:59,768 INFO Loading params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz: 0.3 second.
2024-04-26 17:32:59,768 INFO Building model: 0.3 second.
2024-04-26 17:33:12,513 INFO Creating forcing variables: 12 seconds.
2024-04-26 17:33:15,242 INFO Converting GRIB to xarray: 2 seconds.
Killed

BigShuiTai commented 6 months ago

I have written code to solve this issue myself; you can check my forked repository. Hope this helps!

I found your forked repository of graphcast, but after I installed it and ran ai-models, the inference process got killed at the first step: ...


If you have GFS analysis data, you can run demo_grib2nc_gfs.py to process it, then run demo_inference.py to run inference without the ai-models library.

xionghan7427 commented 5 months ago

Changing the code to run the GraphCast_small model solved this problem for me:

  1. First download the low-resolution asset: params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
  2. Modify ai_models_graphcast/ai_models_graphcast/model.py: at lines 57-58, make download_files refer to the low-resolution asset, and at line 68 change to grid = [1.0, 1.0]
  3. Modify ai_models_graphcast/ai_models_graphcast/input.py at line 60: change the code to .reshape(len(forcing_variables), len(dates), 181, 360)

These changes reduce the resolution to 1.0 degree and thus require much less memory; a sketch of the edits is shown just below.
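
A minimal sketch of the edits described above, assuming the current layout of ai_models_graphcast (exact line numbers can differ between versions, and the variable names in the input.py fragment are illustrative placeholders):

# ai_models_graphcast/model.py: use the 1.0-degree GraphCast_small checkpoint
download_files = [
    (
        "params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 -"
        " pressure levels 13 - mesh 2to5 - precipitation input and output.npz"
    ),
    "stats/diffs_stddev_by_level.nc",
    "stats/mean_by_level.nc",
    "stats/stddev_by_level.nc",
]

# ai_models_graphcast/model.py: request input fields on a 1x1 degree grid
grid = [1.0, 1.0]  # was [0.25, 0.25]

# ai_models_graphcast/input.py: the forcing array must match the coarser grid,
# i.e. 181 latitudes x 360 longitudes instead of 721 x 1440
# ("forcings" stands for the array reshaped at line 60)
forcings = forcings.reshape(len(forcing_variables), len(dates), 181, 360)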

327850200 commented 5 months ago

Changing the code to run the GraphCast_small model solved this problem for me:

  1. First download the low-resolution asset: params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
  2. Modify ai_models_graphcast/ai_models_graphcast/model.py: at lines 57-58, make download_files refer to the low-resolution asset, and at line 68 change to grid = [1.0, 1.0]
  3. Modify ai_models_graphcast/ai_models_graphcast/input.py at line 60: change the code to .reshape(len(forcing_variables), len(dates), 181, 360)

These changes reduce the resolution to 1.0 degree and thus require much less memory.

Can you provide the specific operation details? For example, how to download the data?

xionghan7427 commented 5 months ago

Hello all, There is no easy solution for this, as it would require a refactor of the code. We will consider it a feature request for the future, but currently do not have any capacity to fulfil it. Thanks, Matthew

For the Pangu model, manually creating the InferenceSession and destroying it after each use inside the stepper solved my problem.
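
For illustration, a minimal sketch of that pattern with onnxruntime; the file name, the single "input"/single-output graph, and the stepping loop are placeholder assumptions, not the actual ai-models-panguweather code:

import numpy as np
import onnxruntime as ort

def run_rollout(initial_fields, num_steps, model_path="pangu_weather_6.onnx"):
    """Create a fresh InferenceSession per step so GPU memory is released between steps."""
    fields = initial_fields.astype(np.float32)
    outputs = []
    for _ in range(num_steps):
        session = ort.InferenceSession(
            model_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        # Placeholder input/output names; the real Pangu ONNX graph defines its own.
        (fields,) = session.run(None, {"input": fields})
        outputs.append(fields)
        del session  # drop the session so its GPU arena can be reclaimed before the next step
    return outputs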

xionghan7427 commented 2 months ago

Changing the code to run the GraphCast_small model solved this problem for me:

  1. First download the low-resolution asset: params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
  2. Modify ai_models_graphcast/ai_models_graphcast/model.py: at lines 57-58, make download_files refer to the low-resolution asset, and at line 68 change to grid = [1.0, 1.0]
  3. Modify ai_models_graphcast/ai_models_graphcast/input.py at line 60: change the code to .reshape(len(forcing_variables), len(dates), 181, 360)

These changes reduce the resolution to 1.0 degree and thus require much less memory.

Can you provide the specific operation details? For example, how to download the data?

In model.py you can change the code at lines 57-58 to:

class GraphcastModel(Model):
    download_url = "https://storage.googleapis.com/dm_graphcast/{file}"
    expver = "dmgc"

    # Download
    download_files = [
        (
            # "params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 -"
            # " pressure levels 13 - mesh 2to6 - precipitation output only.npz"
            "params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 -"
            " pressure levels 13 - mesh 2to5 - precipitation input and output.npz"
        ),
        "stats/diffs_stddev_by_level.nc",
        "stats/mean_by_level.nc",
        "stats/stddev_by_level.nc",
    ]

...

you can manually download the file from the download_url above.
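
For example, here is a minimal sketch of fetching the asset with the Python standard library (this assumes the object in the dm_graphcast bucket is publicly readable and named exactly as in download_files; ./graphcast_assets is the directory later passed to ai-models with --assets):

import os
import urllib.parse
import urllib.request

download_url = "https://storage.googleapis.com/dm_graphcast/{file}"
asset = (
    "params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 -"
    " pressure levels 13 - mesh 2to5 - precipitation input and output.npz"
)

# Keep the params/ prefix so ai-models can find the file under --assets ./graphcast_assets
target = os.path.join("graphcast_assets", asset)
os.makedirs(os.path.dirname(target), exist_ok=True)

# Spaces in the object name must be percent-encoded in the request URL.
urllib.request.urlretrieve(download_url.format(file=urllib.parse.quote(asset)), target)
print("saved", target)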

see-ann commented 2 months ago

Changing the code to run the GraphCast_small model solved this problem for me: ...

In model.py you can change the code at lines 57-58 to: ... you can manually download the file from the download_url above.

Once you make those changes, how do you actually run ai-models?

see-ann commented 2 months ago

Getting a segmentation fault:

2024-08-14 21:24:09,142 INFO Loading params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz: 0.4 second.
2024-08-14 21:24:09,142 INFO Building model: 0.4 second.
2024-08-14 21:24:09,582 INFO Creating forcing variables: 0.4 second.
2024-08-14 21:24:10,085 INFO Converting GRIB to xarray: 0.5 second.
2024-08-14 21:24:10,398 INFO Reindexing: 0.3 second.
2024-08-14 21:24:10,401 INFO Creating training data: 1 second.
2024-08-14 21:24:11,068 INFO Extracting input targets: 0.6 second.
2024-08-14 21:24:11,069 INFO Creating input data (total): 1 second.
Segmentation fault

xionghan7427 commented 2 months ago

Yes, you need Mars access.

Best, Han

On Wed, Aug 14, 2024 at 3:46 PM Sean Wang @.***> wrote:

So I've tried ai-models --assets ./graphcast_assets graphcast, where ./graphcast_assets holds params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz, but I'm running into the following: ecmwfapi.api.APIException: "ecmwf.API error 1: User ' has no access to services/mars"

Do you need access to mars?
