NVIDIA / earth2mip

Earth-2 Model Intercomparison Project (MIP) is a Python framework that enables climate researchers and scientists to intercompare AI models for weather and climate.
https://nvidia.github.io/earth2mip/
Apache License 2.0

🐛[BUG]: Graphcast Operational Input/Output Channel Mismatch #156

pgarg7 closed this issue 7 months ago

pgarg7 commented 8 months ago

Version

source-main

On which installation method(s) does this occur?

Source, Pip

Describe the issue

I tried running graphcast_operational to obtain deterministic scores using inference_medium_range. However, I observed a shape mismatch between the model's input and output variables: the output channels contain an additional variable, "total precipitation 6 hour" (tp06). This mismatch results in an error when running score_deterministic for graphcast_operational.

Code example

import datetime
from earth2mip.networks import get_model
from earth2mip.initial_conditions import cds
from earth2mip.inference_ensemble import run_basic_inference
from earth2mip.inference_medium_range import score_deterministic
import numpy as np

time_loop = get_model("e2mip://graphcast_operational", device="cuda:0")
data_source = cds.DataSource(time_loop.in_channel_names)

ds = run_basic_inference(time_loop, n=5, data_source=data_source, time=datetime.datetime(2018, 1, 1))

scores = score_deterministic(time_loop, data_source=data_source, n=3, initial_times=[datetime.datetime(2018, 1, 1)], time_mean=np.zeros((len(time_loop.in_channel_names), 721, 1440)))
print(scores)

Error

Traceback (most recent call last):
  File "/workspace/graphcast_test.py", line 14, in <module>
    scores = score_deterministic(time_loop,data_source=data_source,n=3,initial_times=[datetime.datetime(2018, 1, 1)],time_mean=np.zeros((len(time_loop.in_channel_names), 721, 1440)))
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_medium_range.py", line 219, in score_deterministic
    save_scores(
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_medium_range.py", line 289, in save_scores
    run_forecast(
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_medium_range.py", line 122, in run_forecast
    channels = [
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_medium_range.py", line 123, in <listcomp>
    data_source.channel_names.index(name) for name in model.out_channel_names
ValueError: 'tp06' is not in list
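The failure is a plain list lookup: run_forecast maps every name in model.out_channel_names to an index in data_source.channel_names, and list.index raises ValueError for any name that is absent. The sketch below reproduces that with hypothetical, shortened channel lists (the real GraphCast lists are much longer) and shows one possible guard that keeps only the output channels the data source can supply. The helper names are illustrative, not earth2mip functions, and this is not the fix that was merged.

```python
# Hypothetical, shortened channel lists; the real GraphCast lists are longer.
in_channel_names = ["z500", "t850", "u10m", "v10m"]
# The model emits one extra output channel (6-hour total precipitation).
out_channel_names = in_channel_names + ["tp06"]

# As in the report, the data source is built from the *input* channels only.
data_source_channel_names = list(in_channel_names)


def channel_indices(source_names, wanted_names):
    """Mirror the failing lookup: raises ValueError for any missing name."""
    return [source_names.index(name) for name in wanted_names]


def scorable_channels(source_names, wanted_names):
    """One possible guard: keep only channels the data source can verify."""
    available = set(source_names)
    return [name for name in wanted_names if name in available]


try:
    channel_indices(data_source_channel_names, out_channel_names)
except ValueError as err:
    print(err)  # 'tp06' is not in list

kept = scorable_channels(data_source_channel_names, out_channel_names)
print(kept)  # ['z500', 't850', 'u10m', 'v10m']
print(channel_indices(data_source_channel_names, kept))  # [0, 1, 2, 3]
```

Silently dropping tp06 would leave it out of the scores entirely, so the more complete option is to let the data source provide the output channels as well.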

Environment details

OS platform and distribution: Ubuntu 22.04.3 LTS (x86_64)
PyTorch version: 2.1.0a0+32f93b1
Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
CUDA/cuDNN version: 12.2.140
GPU models and configuration: NVIDIA H100 80GB HBM3

Modulus version: 23.11
How Modulus is used (Docker image/ bare-metal installation): Docker from nvcr.io
(If using Modulus Docker image) Exact docker-run command: I am running on Eos, passing NVIDIA's Modulus Docker image through the --container-image argument.

nbren12 commented 8 months ago

I believe this is fixed by #154.

NickGeneva commented 8 months ago

I believe so; it seems to solve the issue. The problem originally stems from get_initial_condition_for_model always using in_channels.
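A minimal sketch of the idea, assuming the scoring path needs verification data for every output channel while the initial condition needs only the input channels, and assuming the data source can actually supply tp06: request the order-preserving union of both lists from the data source, then select each subset by name. required_channels and select are hypothetical helpers, not earth2mip functions.

```python
def required_channels(in_names, out_names):
    """Order-preserving union: inputs first, then any output-only channels."""
    return list(dict.fromkeys(list(in_names) + list(out_names)))


def select(source_names, wanted_names):
    """Indices of wanted_names within source_names (all must be present)."""
    position = {name: i for i, name in enumerate(source_names)}
    return [position[name] for name in wanted_names]


# Hypothetical, shortened channel lists.
in_names = ["z500", "t850", "u10m", "v10m"]
out_names = in_names + ["tp06"]

source_names = required_channels(in_names, out_names)
print(source_names)                     # ['z500', 't850', 'u10m', 'v10m', 'tp06']
print(select(source_names, in_names))   # [0, 1, 2, 3] -> initial condition
print(select(source_names, out_names))  # [0, 1, 2, 3, 4] -> verification
```

dict.fromkeys is used because it deduplicates while keeping first-seen order, so the input channels keep their original positions.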

@pgarg7, maybe you could look at the method implemented there and add this?

https://github.com/NVIDIA/earth2mip/blob/6696517588eb3f00c19c24ff8028de85dcc236d1/earth2mip/inference_medium_range.py#L151

https://github.com/NVIDIA/earth2mip/blob/6696517588eb3f00c19c24ff8028de85dcc236d1/earth2mip/initial_conditions/__init__.py#L58

I would personally be for switching out the three parameters:

time: datetime.datetime,
time_levels: int,
time_step: datetime.timedelta = datetime.timedelta(hours=0),

for just one:

times: Union[datetime.datetime, list[datetime.datetime]]
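To illustrate that the single list carries the same information as the three parameters, here is a sketch that expands the old arguments into an explicit list of times. legacy_times is a hypothetical helper, and it assumes the levels are history steps ending at time, oldest first; the actual semantics in earth2mip may differ.

```python
import datetime
from typing import List


def legacy_times(
    time: datetime.datetime,
    time_levels: int,
    time_step: datetime.timedelta = datetime.timedelta(hours=0),
) -> List[datetime.datetime]:
    """Expand the three legacy parameters into one explicit list of times.

    Assumption: the levels are history steps ending at `time`, oldest first.
    """
    return [time - k * time_step for k in reversed(range(time_levels))]


t0 = datetime.datetime(2018, 1, 1)
print(legacy_times(t0, time_levels=2, time_step=datetime.timedelta(hours=6)))
# [datetime.datetime(2017, 12, 31, 18, 0), datetime.datetime(2018, 1, 1, 0, 0)]
```

One side effect of the explicit list: with the legacy default time_step of zero hours and more than one level, the expansion repeats the same time, which a caller passing times directly would notice immediately.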
nbren12 commented 8 months ago


Not opposed, but can we avoid the union type? It adds some convenience for interactive use, but also some ambiguity. For example, what would the shape of the output be in the datetime.datetime case? One could argue it should be (B, C, ...) rather than (B, history, C, ...).
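With a list-only argument the output shape is never in question: a single time is passed as [time] and yields a history axis of length 1. The sketch below shows this with NumPy, assuming a (B, history, C, H, W) layout with B = 1; load_field is a hypothetical stand-in for reading one time step, and the grid is a tiny placeholder rather than 721 x 1440.

```python
import datetime
from typing import List

import numpy as np

N_CHANNELS, NLAT, NLON = 3, 4, 8  # tiny placeholder grid, not 721 x 1440


def load_field(time: datetime.datetime) -> np.ndarray:
    """Hypothetical stand-in for reading one time step: shape (C, H, W)."""
    return np.full((N_CHANNELS, NLAT, NLON), time.hour, dtype=np.float32)


def initial_condition(times: List[datetime.datetime]) -> np.ndarray:
    """Always returns (B, history, C, H, W), with B = 1 and history = len(times)."""
    return np.stack([load_field(t) for t in times])[np.newaxis]


t0 = datetime.datetime(2018, 1, 1)
print(initial_condition([t0]).shape)  # (1, 1, 3, 4, 8)
print(initial_condition([t0 - datetime.timedelta(hours=6), t0]).shape)  # (1, 2, 3, 4, 8)
```

A caller who wants (B, C, H, W) for one time can drop the history axis explicitly, for example with [:, 0], instead of the function guessing.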

pgarg7 commented 8 months ago

I agree! I can give it a shot, implement the method mentioned above, and see whether it improves the generalizability of the scoring module.

NickGeneva commented 8 months ago

Not opposed, but can we avoid the union type?

I'm for it. The union was only there to allow some slight convenience for a single time, but I think just requiring list[datetime] is fine, and it also makes the API clearer.
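If the parameter becomes list-only, a small check at the API boundary can turn a bare datetime into a clear error instead of an ambiguous shape. This is a sketch; check_times is a hypothetical name, not part of earth2mip.

```python
import datetime
from typing import List


def check_times(times: List[datetime.datetime]) -> List[datetime.datetime]:
    """Reject a bare datetime instead of guessing which output shape is wanted."""
    if isinstance(times, datetime.datetime):
        raise TypeError("pass a list of datetimes, e.g. [time], not a bare datetime")
    times = list(times)
    if not times:
        raise ValueError("times must contain at least one datetime")
    if not all(isinstance(t, datetime.datetime) for t in times):
        raise TypeError("every element of times must be a datetime.datetime")
    return times


t0 = datetime.datetime(2018, 1, 1)
print(check_times([t0]))  # [datetime.datetime(2018, 1, 1, 0, 0)]
```

A bare datetime is not iterable, so list(t0) would fail anyway; the explicit isinstance check only exists to give a more helpful message.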