ecmwf-lab / ai-models

Apache License 2.0
353 stars 54 forks

issue with input data from cds for graphcast #10

Closed mixty1101 closed 10 months ago

mixty1101 commented 1 year ago

When I run `ai-models --input cds --date 20230110 --time 0000 graphcast` to test GraphCast with CDS ERA5 data, I get the files below in /tmp/climetlab-user and the errors that follow. Does the model use these .cache files for the inference? Also, which physical parameter does the key='time' conflict relate to (maybe the accumulated precipitation)? Should I revise the download scripts to run the inference?

files at /tmp/climetlab-user:

- cache-2.db
- cds-retriever-1a1d7637aaaf5616ccdf2c650bb81dd5b6613e9f032826fa57c364fbd61136c5.cache
- cds-retriever-30f1a29a4568ca8f99f88d3e20200b657d39a0b2a54534e9875953cc04808e8d.cache
- cds-retriever-68098f31c2301f0a4c62ac25822c3a33eb28b44a8c8e22c17ba136f1067dff9e.cache
- cds-retriever-7576dbeeafe941786b728196693f1945d77a76cd427776d24eccb6dbbfa42c02.cache
- grib-index-650de97dadfda5522ef6065aba6d53070b7f5474f2c88d3ecfae8603f179df6a.json
- grib-index-74269e34badef4add821af3244ea8d64cc621103c1d653d33b7896d707c62fcf.json
- grib-index-8a9615581ca2c8859f1fe76024ba7ace185fd86c85b22b916c2befd7cec92fae.json
- grib-index-a6fc434c8b189aa395892a5544257e762196fd7599dde47bf169ed1c9b301b12.json

error messages:

2023-09-18 19:06:25,607 INFO Creating training dataset
2023-09-18 19:06:28,613 INFO Creating input data: 3 minutes.
2023-09-18 19:06:28,613 INFO Total time: 3 minutes 1 second.
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/ai_models/__main__.py", line 274, in main
    _main()
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/ai_models/__main__.py", line 247, in _main
    model.run()
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 199, in run
    training_xarray, time_deltas = create_training_xarray(
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/ai_models_graphcast/input.py", line 60, in create_training_xarray
    fields_sfc.to_xarray().rename(GRIB_TO_XARRAY_SFC).isel(number=0, surface=0)
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/climetlab/readers/grib/xarray.py", line 106, in to_xarray
    result = xr.open_dataset(
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/xarray/backends/api.py", line 570, in open_dataset
    backend_ds = backend.open_dataset(
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/xarray_plugin.py", line 108, in open_dataset
    store = CfGribDataStore(
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/xarray_plugin.py", line 40, in __init__
    self.ds = opener(filename, **backend_kwargs)
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 750, in open_fieldset
    return open_from_index(filtered_index, read_keys, time_dims, extra_coords, **kwargs)
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 726, in open_from_index
    dimensions, variables, attributes, encoding = build_dataset_components(
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 680, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/user/anaconda3/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 611, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='time' value=Variable(dimensions=('time',), data=array([1673287200, 1673308800])) new_value=Variable(dimensions=('time',), data=array([1673244000, 1673287200]))
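For what it's worth, the two arrays in that error are Unix timestamps; decoding them (a quick stdlib sketch, nothing GraphCast-specific) shows the two groups of fields carry different valid times, which is exactly what cfgrib refuses to merge — consistent with an accumulated field such as precipitation using a different time/step encoding:

```python
from datetime import datetime, timezone

# Timestamps copied from the DatasetBuildError above
value = [1673287200, 1673308800]      # existing 'time' coordinate
new_value = [1673244000, 1673287200]  # conflicting 'time' coordinate

def decode(stamps):
    """Convert Unix epoch seconds to UTC datetimes."""
    return [datetime.fromtimestamp(s, tz=timezone.utc) for s in stamps]

print(decode(value))      # 2023-01-09 18:00 and 2023-01-10 00:00 UTC
print(decode(new_value))  # 2023-01-09 06:00 and 2023-01-09 18:00 UTC
```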
whu-dyf commented 12 months ago

Hello, have you resolved the issue yet? I'm experiencing the same problem.

b8raoult commented 11 months ago

Can you please try again? We have made some changes to the code that should fix that issue.

crlna16 commented 11 months ago

For me the issue persists with the latest version of the code:

Traceback (most recent call last):
  File "/home/k/k202141/.conda/envs/ai-models/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/ai_models/__main__.py", line 274, in main
    _main()
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/ai_models/__main__.py", line 247, in _main
    model.run()
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 199, in run
    training_xarray, time_deltas = create_training_xarray(
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/ai_models_graphcast/input.py", line 60, in create_training_xarray
    fields_sfc.to_xarray().rename(GRIB_TO_XARRAY_SFC).isel(number=0, surface=0)
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/climetlab/readers/grib/xarray.py", line 106, in to_xarray
    result = xr.open_dataset(
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/xarray/backends/api.py", line 570, in open_dataset
    backend_ds = backend.open_dataset(
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/xarray_plugin.py", line 108, in open_dataset
    store = CfGribDataStore(
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/xarray_plugin.py", line 40, in __init__
    self.ds = opener(filename, **backend_kwargs)
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 750, in open_fieldset
    return open_from_index(filtered_index, read_keys, time_dims, extra_coords, **kwargs)
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 726, in open_from_index
    dimensions, variables, attributes, encoding = build_dataset_components(
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 680, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/k/k202141/.conda/envs/ai-models/lib/python3.10/site-packages/cfgrib/dataset.py", line 611, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='time' value=Variable(dimensions=('time',), data=array([1696010400, 1696032000])) new_value=Variable(dimensions=('time',), data=array([1695967200, 1696010400]))
b8raoult commented 11 months ago

The call stack shows that you are still using the old code. The new code does not call cfgrib. Have a look at the same issue reported here: https://github.com/ecmwf-lab/ai-models-graphcast/issues/1

mixty1101 commented 11 months ago

Thanks a lot for your help @b8raoult! I followed the instructions in https://github.com/ecmwf-lab/ai-models-graphcast/issues/1 after updating ai-models from 0.2.5 to 0.2.8.

conda create -n ai-models028 python=3.10
conda activate ai-models028
conda install cudatoolkit==11.8.0 
pip install ai-models==0.2.8

git clone https://github.com/ecmwf-lab/ai-models-graphcast.git
cd ai-models-graphcast
pip3 install --upgrade -e .
pip3 install -r requirements-gpu.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

The above-mentioned issue was solved. However, a small problem arose, as follows:

ai-models --input cds --date 20230110 --time 0000 graphcast
2023-10-08 18:31:16,385 INFO Writing results to graphcast.grib.
/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/__init__.py:23: UserWarning: ecCodes 2.31.0 or higher is recommended. You are running version 2.30.0
  warnings.warn(
2023-10-08 18:31:16,677 INFO Model description: 
GraphCast model at 0.25deg resolution, with 13 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and fine-tuned on HRES-fc0 data from
2016 to 2021 and can be causally evaluated on 2022 and later years. This model
does not take `total_precipitation_6hr` as inputs and can make predictions in an
operational setting (i.e., initialised from HRES-fc0).

2023-10-08 18:31:16,677 INFO Model license: 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.

2023-10-08 18:31:16,678 INFO Loading params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz: 0.3 second.
2023-10-08 18:31:16,678 INFO Building model: 0.3 second.
2023-10-08 18:31:16,678 INFO Loading surface fields from CDS
2023-10-08 18:31:16,751 INFO Loading pressure fields from CDS
2023-10-08 18:31:28,367 INFO Creating forcing variables: 11 seconds.
2023-10-08 18:31:33,994 INFO Converting GRIB to xarray: 5 seconds.
2023-10-08 18:31:37,726 INFO Reindexing: 3 seconds.
2023-10-08 18:31:37,759 INFO Creating training data: 21 seconds.
2023-10-08 18:31:44,758 INFO Extracting input targets: 6 seconds.
2023-10-08 18:31:44,758 INFO Creating input data (total): 28 seconds.
2023-10-08 18:31:45,771 INFO Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
2023-10-08 18:31:45,772 INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2023-10-08 18:32:38.298253: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %pad.149 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.365, bf16[] %constant.768), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/share/project_master/ai-models-graphcast/ai_models_graphcast/model.py" source_line=168}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-10-08 18:32:43.824286: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 6.526109636s
Constant folding an instruction is taking > 1s:

  %pad.149 = bf16[3114720,8]{1,0} pad(bf16[3114720,4]{1,0} %constant.365, bf16[] %constant.768), padding=0_0x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/mesh2grid_gnn/_embed/mesh2grid_gnn/sequential/encoder_edges_mesh2grid_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/share/project_master/ai-models-graphcast/ai_models_graphcast/model.py" source_line=168}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-10-08 18:32:46.097633: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:

  %pad.1 = bf16[1618752,8]{1,0} pad(bf16[1618745,4]{1,0} %constant.374, bf16[] %constant.687), padding=0_7x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_embed/grid2mesh_gnn/sequential/encoder_edges_grid2mesh_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/share/project_master/ai-models-graphcast/ai_models_graphcast/model.py" source_line=168}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-10-08 18:32:47.464672: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3.367102627s
Constant folding an instruction is taking > 2s:

  %pad.1 = bf16[1618752,8]{1,0} pad(bf16[1618745,4]{1,0} %constant.374, bf16[] %constant.687), padding=0_7x0_4, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/remat/grid2mesh_gnn/_embed/grid2mesh_gnn/sequential/encoder_edges_grid2mesh_mlp/linear_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=bfloat16]" source_file="/share/project_master/ai-models-graphcast/ai_models_graphcast/model.py" source_line=168}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2023-10-08 18:33:02,922 INFO Doing full rollout prediction in JAX: 1 minute 18 seconds.
2023-10-08 18:33:02,922 INFO Converting output xarray to GRIB and saving
2023-10-08 18:33:18,984 ERROR Error setting expver=None
2023-10-08 18:33:18,984 ERROR Invalid type of value when setting key 'expver'.
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/climetlab/readers/grib/codes.py", line 243, in set
    return eccodes.codes_set(self.handle, name, value)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/gribapi.py", line 2133, in grib_set
    raise GribInternalError(
gribapi.errors.GribInternalError: Invalid type of value when setting key 'expver'.
2023-10-08 18:33:19,137 ERROR Error setting expver=None
2023-10-08 18:33:19,137 ERROR Invalid type of value when setting key 'expver'.
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/climetlab/readers/grib/codes.py", line 243, in set
    return eccodes.codes_set(self.handle, name, value)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/gribapi.py", line 2133, in grib_set
    raise GribInternalError(
...
2023-10-08 18:35:09,466 ERROR Error setting expver=None
2023-10-08 18:35:09,466 ERROR Invalid type of value when setting key 'expver'.
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/climetlab/readers/grib/codes.py", line 243, in set
    return eccodes.codes_set(self.handle, name, value)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/gribapi.py", line 2133, in grib_set
    raise GribInternalError(
gribapi.errors.GribInternalError: Invalid type of value when setting key 'expver'.
2023-10-08 18:35:09,474 INFO Saving output data: 2 minutes 6 seconds.
2023-10-08 18:35:09,545 INFO Total time: 3 minutes 53 seconds.

It seems that the problem happens while setting key values during the write to the GRIB file. Is the problem related to the older version of ecCodes? When I open the output GRIB with xarray, it shows the following: u10, v10, and tp are skipped.

xr.open_dataset('graphcast.grib')
/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/__init__.py:23: UserWarning: ecCodes 2.31.0 or higher is recommended. You are running version 2.30.0
  warnings.warn(
Ignoring index file 'graphcast.grib.923a8.idx' older than GRIB file
skipping variable: paramId==165 shortName='u10'
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 680, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 611, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='heightAboveGround' value=Variable(dimensions=(), data=2.0) new_value=Variable(dimensions=(), data=10.0)
skipping variable: paramId==166 shortName='v10'
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 680, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 611, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='heightAboveGround' value=Variable(dimensions=(), data=2.0) new_value=Variable(dimensions=(), data=10.0)
skipping variable: paramId==228 shortName='tp'
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 680, in build_dataset_components
    dict_merge(variables, coord_vars)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/cfgrib/dataset.py", line 611, in dict_merge
    raise DatasetBuildError(
cfgrib.dataset.DatasetBuildError: key present and new value is different: key='time' value=Variable(dimensions=(), data=1673308800) new_value=Variable(dimensions=(), data=1673287200)
<xarray.Dataset>
Dimensions:            (step: 40, latitude: 721, longitude: 1440,
                        isobaricInhPa: 13)
Coordinates:
    time               datetime64[ns] ...
  * step               (step) timedelta64[ns] 0 days 06:00:00 ... 10 days 00:...
    heightAboveGround  float64 ...
  * latitude           (latitude) float64 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * longitude          (longitude) float64 0.0 0.25 0.5 ... 359.2 359.5 359.8
    valid_time         (step) datetime64[ns] ...
    meanSea            float64 ...
  * isobaricInhPa      (isobaricInhPa) float64 1e+03 925.0 850.0 ... 100.0 50.0
Data variables:
    t2m                (step, latitude, longitude) float32 ...
    msl                (step, latitude, longitude) float32 ...
    t                  (step, isobaricInhPa, latitude, longitude) float32 ...
    z                  (step, isobaricInhPa, latitude, longitude) float32 ...
    u                  (step, isobaricInhPa, latitude, longitude) float32 ...
    v                  (step, isobaricInhPa, latitude, longitude) float32 ...
    w                  (step, isobaricInhPa, latitude, longitude) float32 ...
    q                  (step, isobaricInhPa, latitude, longitude) float32 ...
Attributes:
    GRIB_edition:            2
    GRIB_centre:             ecmf
    GRIB_centreDescription:  European Centre for Medium-Range Weather Forecasts
    GRIB_subCentre:          0
    Conventions:             CF-1.7
    institution:             European Centre for Medium-Range Weather Forecasts
    history:                 2023-10-08T18:53 GRIB to CDM+CF via cfgrib-0.9.1...
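As a side note on the skipped variables above: cfgrib can open each conflicting variable on its own through its filter_by_keys backend kwarg, which sidesteps the heightAboveGround/time merge conflicts. A sketch (assuming the graphcast.grib produced above is present; the shortName values are taken from the skip messages):

```python
import os

import xarray as xr

# ecCodes keys used to select a single variable; value from the skip messages above
FILTER = {"shortName": "u10"}

def open_single(path, short_name):
    """Open one GRIB variable so cfgrib never has to merge conflicting coordinates."""
    return xr.open_dataset(
        path,
        engine="cfgrib",
        backend_kwargs={"filter_by_keys": {"shortName": short_name}},
    )

# Guarded so the sketch is harmless without the file
if os.path.exists("graphcast.grib"):
    u10 = open_single("graphcast.grib", "u10")
    print(u10)
```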
b8raoult commented 11 months ago

Try the following: --path 'output-{levtype}.grib'. That will create two GRIB files: one for single-level parameters and one for upper-air parameters. If that does not help, try --path 'output-{shortName}.grib', which should create a file per parameter.
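For reference, the {levtype}/{shortName} placeholders are ecCodes keys that ai-models substitutes per output field; conceptually the expansion behaves like Python's str.format (a sketch with hypothetical field metadata, not the actual ai-models code):

```python
# Hypothetical metadata for two fields; ai-models reads such keys from each GRIB message
fields = [
    {"levtype": "sfc", "shortName": "2t"},
    {"levtype": "pl", "shortName": "z"},
]

template = "output-{levtype}.grib"
# Each field is routed to the file its expanded template names
paths = sorted({template.format(**f) for f in fields})
print(paths)  # ['output-pl.grib', 'output-sfc.grib']
```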

mixty1101 commented 11 months ago

I tried `ai-models --input cds --date 20230110 --time 0000 graphcast --path 'output-{shortName}-{step}-{levtype}-{level}.grib'`. The output GRIB file names look like "output-q-240-pl-300.grib", but I still got the error:

Traceback (most recent call last):
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/climetlab/readers/grib/codes.py", line 243, in set
    return eccodes.codes_set(self.handle, name, value)
  File "/home/user/anaconda3/envs/ai-models028/lib/python3.10/site-packages/gribapi/gribapi.py", line 2133, in grib_set
    raise GribInternalError(
gribapi.errors.GribInternalError: Invalid type of value when setting key 'expver'.
2023-10-09 18:18:14,364 ERROR Error setting expver=None
2023-10-09 18:18:14,365 ERROR Invalid type of value when setting key 'expver'.
Traceback (most recent call last):

So I added `output.to_netcdf('graphcast_test.nc')` in output.py and then use the .nc file as the output. Is that appropriate for the inference? @b8raoult

def save_output_xarray(
    *,
    output,
    target_variables,
    write,
    all_fields,
    ordering,
    lead_time,
    hour_steps,
    lagged,
):
    LOG.info("Converting output xarray to GRIB and saving")

    ### output with netcdf format before converting to grib
    output.to_netcdf('graphcast_test.nc')

    output["total_precipitation_6hr"] = output.data_vars[
        "total_precipitation_6hr"
    ].cumsum(dim="time")
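One caveat with this workaround: the cumsum above is applied after to_netcdf, so the NetCDF holds per-6h precipitation increments while the GRIB output holds accumulations from the start of the forecast. Converting between the two is a plain cumsum/diff pair; a stdlib-only sketch (integer values so the round trip is exact):

```python
from itertools import accumulate

per_step = [5, 0, 12, 3]                  # per-6h increments (e.g. tenths of mm)
accumulated = list(accumulate(per_step))  # running total, as in the GRIB output
print(accumulated)  # [5, 5, 17, 20]

# Back to increments: keep the first value, then take successive differences
recovered = [accumulated[0]] + [b - a for a, b in zip(accumulated, accumulated[1:])]
assert recovered == per_step
```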
louisPoulain commented 11 months ago

I have the same issue as @mixty1101 (issue comment) when running Pangu. The error did not occur a week or so ago. I tried with a path using {step} and with a path using no ecCodes keys.

b8raoult commented 11 months ago

Can you give the exact command line that you run?

louisPoulain commented 11 months ago

The exact command was:

    while read -r line
    do
    read -r date time lead_time <<< "$line"
    dir_out="/scratch/lpoulain/panguweather/pangu_d_${date}_t_${time}_{step}h.grib"
    ONNXRUNTIME=onnxruntime-gpu ai-models --assets ./panguweather/ --input cds --date "$date" --time "$time" --lead-time "$lead_time" --path "$dir_out" panguweather
    done < input_params_pangu.txt

I also tried setting dir_out simply to pangu.grib, with no change.

Update
I added --expver 1 to the command line (though I don't really understand what it does) and got the error:

    ECCODES ERROR   :  Wrong length for experimentVersionNumber. It has to be 4  
    2023-10-12 11:20:06,826 ERROR Error setting expver=1
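For context, ecCodes stores expver (experimentVersionNumber) as a fixed four-character field, hence the "has to be 4" error; ECMWF operational data uses '0001'. A hedged sketch of the kind of padding a fix would need (illustrative only, not the actual ai-models code):

```python
def normalize_expver(expver):
    """Left-pad expver with zeros to the 4 characters ecCodes expects (illustrative only)."""
    return str(expver).zfill(4)

print(normalize_expver(1))       # '0001'
print(normalize_expver("abcd"))  # 'abcd'
```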
b8raoult commented 11 months ago

Thanks for reporting that issue. I just created version 0.2.13 of ai-models that should fix the problem.

louisPoulain commented 11 months ago

Thanks, the issue is fixed

idharssi2020 commented 11 months ago

I've got ai-models graphcast running on my local machine. This is my conda setup for graphcast


cname=ai_models0211
conda create -n $cname -c conda-forge python=3.10 gpustat
conda activate $cname
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit cuda-nvcc
python -m pip install --default-timeout=120 --no-cache-dir --upgrade \
        nvidia-cudnn-cu11==8.6.0.163 \
        "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip install xxhash
pip install ai-models==0.2.11

git clone https://github.com/ecmwf-lab/ai-models-graphcast.git
cd ai-models-graphcast
pip3 install --upgrade -e .
pip3 install -r requirements-gpu.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

To run graphcast


nvidia-smi
gpustat -a
pip freeze | grep cud

mkdir -p $CONDA_PREFIX/etc/conda/activate.d
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$CUDNN_PATH/lib:$LD_LIBRARY_PATH' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
ai-models --input cds --date 20230110 --time 0000 --expver abcd graphcast --debug

Sample text output


2023-10-13 10:25:24,385 INFO Writing results to graphcast.grib.
2023-10-13 10:25:24,860 INFO Model description: 
GraphCast model at 0.25deg resolution, with 13 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and fine-tuned on HRES-fc0 data from
2016 to 2021 and can be causally evaluated on 2022 and later years. This model
does not take `total_precipitation_6hr` as inputs and can make predictions in an
operational setting (i.e., initialised from HRES-fc0).

2023-10-13 10:25:24,860 INFO Model license: 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.

2023-10-13 10:25:24,861 INFO Loading params/GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz: 0.5 second.
2023-10-13 10:25:24,861 INFO Building model: 0.5 second.
2023-10-13 10:25:24,861 INFO Loading surface fields from CDS
2023-10-13 10:25:25,363 INFO Loading pressure fields from CDS
2023-10-13 10:25:34,296 INFO Creating forcing variables: 8 seconds.
2023-10-13 10:25:40,560 INFO Converting GRIB to xarray: 6 seconds.
2023-10-13 10:25:44,562 INFO Reindexing: 4 seconds.
2023-10-13 10:25:44,600 INFO Creating training data: 19 seconds.
2023-10-13 10:25:52,185 INFO Extracting input targets: 7 seconds.
2023-10-13 10:25:52,187 INFO Creating input data (total): 27 seconds.
2023-10-13 10:25:52,371 INFO Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
2023-10-13 10:25:52,377 INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2023-10-13 10:29:07,703 INFO Doing full rollout prediction in JAX: 3 minutes 15 seconds.
2023-10-13 10:29:07,705 INFO Converting output xarray to GRIB and saving
2023-10-13 10:32:09,165 INFO Saving output data: 3 minutes 1 second.
2023-10-13 10:32:09,242 INFO Total time: 6 minutes 47 seconds.