aditya-grover / climate-learn

Source code for ClimateLearn
MIT License

Training on 3D variables does not work anymore #36

Closed: se0ngbin closed this issue 1 year ago

se0ngbin commented 1 year ago

Hello, I recently tried to load ERA5 using the updated climate-learn package:

era5_data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = era_path,
    in_vars = ["temperature"],
    out_vars = ["temperature"],
    train_start_year = Year(1979),
    val_start_year = Year(2011),
    test_start_year = Year(2013),
    end_year = Year(2014),
    pred_range = Days(5),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 1
)

Running the above code produces the following error:

KeyError                                  Traceback (most recent call last)
Cell In [16], line 1
----> 1 era5_data_module = DataModule(
      2     dataset = "ERA5",
      3     task = "forecasting",
      4     root_dir = era_path,
      5     in_vars = ["temperature"],
      6     out_vars = ["temperature"],
      7     train_start_year = Year(1979),
      8     val_start_year = Year(2011),
      9     test_start_year = Year(2013),
     10     end_year = Year(2014),
     11     pred_range = Days(5),
     12     subsample = Hours(6),
     13     batch_size = 128,
     14     num_workers = 1
     15 )

File ~/climate-learn/src/climate_learn/data/module.py:58, in DataModule.__init__(self, dataset, task, root_dir, in_vars, out_vars, train_start_year, val_start_year, test_start_year, end_year, root_highres_dir, history, window, pred_range, subsample, batch_size, num_workers, pin_memory)
     55 caller = eval(f"{dataset.upper()}{task_string}")
     57 train_years = range(train_start_year, val_start_year)
---> 58 self.train_dataset = caller(
     59     root_dir,
     60     root_highres_dir,
     61     in_vars,
     62     out_vars,
     63     history,
     64     window,
     65     pred_range.hours(),
     66     train_years,
     67     subsample.hours(),
     68     "train",
     69 )
     71 val_years = range(val_start_year, test_start_year)
     72 self.val_dataset = caller(
     73     root_dir,
     74     root_highres_dir,
   (...)
     82     "val",
     83 )

File ~/climate-learn/src/climate_learn/data/modules/era5_module.py:122, in ERA5Forecasting.__init__(self, root_dir, root_highres_dir, in_vars, out_vars, history, window, pred_range, years, subsample, split)
    119 self.pred_range = pred_range
    121 inp_data = xr.concat([self.data_dict[k] for k in self.in_vars], dim="level")
--> 122 out_data = xr.concat([self.data_dict[k] for k in self.out_vars], dim="level")
    123 self.inp_data = inp_data.to_numpy().astype(np.float32)
    124 self.out_data = out_data.to_numpy().astype(np.float32)

File ~/climate-learn/src/climate_learn/data/modules/era5_module.py:122, in <listcomp>(.0)
    119 self.pred_range = pred_range
    121 inp_data = xr.concat([self.data_dict[k] for k in self.in_vars], dim="level")
--> 122 out_data = xr.concat([self.data_dict[k] for k in self.out_vars], dim="level")
    123 self.inp_data = inp_data.to_numpy().astype(np.float32)
    124 self.out_data = out_data.to_numpy().astype(np.float32)

KeyError: 'temperature'

The same problem occurs with other pressure-level variables, such as geopotential.

Hritikbansal commented 1 year ago

Hi @se0ngbin, you have to specify which pressure level of the temperature variable you want in your data. Simply stating "temperature" will not work. @jasonjewik, can you take a look here?

jasonjewik commented 1 year ago

@seongbin, from what source did you download your dataset? The only source that provides climate variables at multiple pressure levels is Copernicus.

Currently, ClimateLearn lets you download such data by passing pressure=True as a keyword argument to the data.download function. It will then download your requested variable at the 1000, 850, 500, and 50 hPa pressure levels. Details can be found here: https://github.com/aditya-grover/climate-learn/blob/1066687be76272989c5e8710d729b5763d2191b0/src/climate_learn/data/download.py#L13
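To make the advice above concrete: once a variable has been downloaded at those four levels, the names you can pass for a single level follow the {var}_{level} pattern that shows up later in this thread (for example geopotential_500). The helper below is a hypothetical illustration, not part of ClimateLearn; it only assumes that naming pattern and the level list quoted in the comment above.

```python
# Levels (in hPa) that, per the comment above, pressure=True downloads.
DOWNLOADED_LEVELS = (1000, 850, 500, 50)


def level_specific_names(var, levels=DOWNLOADED_LEVELS):
    """Hypothetical helper: list the {var}_{level} names for one pressure-level variable."""
    return [f"{var}_{level}" for level in levels]


# A bare name such as "temperature" is not among these; choose one level instead.
print(level_specific_names("temperature"))
# ['temperature_1000', 'temperature_850', 'temperature_500', 'temperature_50']
```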

se0ngbin commented 1 year ago

The conversion from xarray to NumPy (self.inp_data = inp_data.to_numpy().astype(np.float32)) is taking too long for me to confirm right now, but I believe the following code snippet should work:

from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

era5_data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = era_path,
    in_vars = ["geopotential"],
    out_vars = ["geopotential_500"],
    train_start_year = Year(1979),
    val_start_year = Year(2011),
    test_start_year = Year(2013),
    end_year = Year(2014),
    pred_range = Days(5),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 16
) 

in_vars only accepts variables that appear in the lists CONSTANTS, PRESSURE_LEVEL_VARS, and SINGLE_LEVEL_VARS. This means I cannot set in_vars to a specific level of a 3D variable, such as Z500 (geopotential at 500 hPa). Instead, I have to choose the whole 3D variable, like geopotential.
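A minimal sketch of the validation rule described above, assuming the check is a plain membership test against the three lists. The list contents here are hypothetical, trimmed-down stand-ins (not the library's real lists), and raising ValueError is an assumption about how a rejection would surface.

```python
# Hypothetical, trimmed-down stand-ins for the library's lists (not their real contents).
CONSTANTS = ["land_sea_mask", "orography"]
PRESSURE_LEVEL_VARS = ["geopotential", "temperature"]
SINGLE_LEVEL_VARS = ["2m_temperature"]

ALLOWED_IN_VARS = set(CONSTANTS + PRESSURE_LEVEL_VARS + SINGLE_LEVEL_VARS)


def validate_in_vars(in_vars):
    """Accept only bare variable names; a level-specific name is not in any list."""
    unknown = [v for v in in_vars if v not in ALLOWED_IN_VARS]
    if unknown:
        raise ValueError(f"unsupported in_vars: {unknown}")
    return list(in_vars)


print(validate_in_vars(["geopotential"]))  # ['geopotential'] is accepted
try:
    validate_in_vars(["geopotential_500"])  # Z500 on its own is rejected
except ValueError as err:
    print(err)  # unsupported in_vars: ['geopotential_500']
```

Because the allowed set only holds bare names, there is no way to ask for one level on the input side under this rule.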

Internally, the DataModule builds a dictionary keyed by {var}_{level}, with one entry for every level in DEFAULT_PRESSURE_LEVELS. The inp_data and out_data arrays are then created from it:

# self.data_dict is built up here, with one {var}_{level} key per pressure level

self.in_vars = list(self.data_dict.keys())  # all keys of the dict built above
self.out_vars = out_vars  # taken directly from the user's argument

# (lines omitted)

# this always succeeds, because in_vars comes from the dict's own keys
inp_data = xr.concat([self.data_dict[k] for k in self.in_vars], dim="plev")
# however, every name in out_vars must already be a key in data_dict, otherwise this raises KeyError
out_data = xr.concat([self.data_dict[k] for k in self.out_vars], dim="plev")

As seen above, every variable in out_vars must be a key in self.data_dict, so it must have the form {var}_{level}. I find this design confusing, since I cannot specify a level for in_vars but must specify one for out_vars.
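The asymmetry can be reproduced without any climate data. The sketch below mirrors the logic described above using plain Python strings in place of xarray arrays. The function names and the level list are hypothetical, and it assumes (as the excerpt suggests) that data_dict is built only from the user's in_vars.

```python
# Hypothetical level list; the thread only says the real one is DEFAULT_PRESSURE_LEVELS.
DEFAULT_PRESSURE_LEVELS = [50, 500, 850, 1000]


def build_data_dict(in_vars):
    """Mimic the described behaviour: one {var}_{level} entry per pressure level."""
    data_dict = {}
    for var in in_vars:
        for level in DEFAULT_PRESSURE_LEVELS:
            data_dict[f"{var}_{level}"] = f"<{var} at {level} hPa>"  # stands in for an array
    return data_dict


def select(data_dict, out_vars):
    in_keys = list(data_dict.keys())  # input side: every key, never a lookup miss
    inp = [data_dict[k] for k in in_keys]
    out = [data_dict[k] for k in out_vars]  # output side: names are looked up verbatim
    return inp, out


data_dict = build_data_dict(["temperature"])
inp, out = select(data_dict, ["temperature_500"])  # works: the key exists
print(out)  # ['<temperature at 500 hPa>']
try:
    select(data_dict, ["temperature"])  # same failure as in the original report
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'temperature'
```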

tung-nd commented 1 year ago

Thanks, Seongbin, for raising this problem. It is indeed confusing. I am writing a new data module that uses IterableDataset, which will also fix this problem. We will be able to specify both the variable and its pressure level in the input and in the output, so it will look like this:

era5_data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = era_path,
    in_vars = ["geopotential_500"],
    out_vars = ["geopotential_500"],
    train_start_year = Year(1979),
    val_start_year = Year(2011),
    test_start_year = Year(2013),
    end_year = Year(2014),
    pred_range = Days(5),
    subsample = Hours(6),
    batch_size = 128,
    num_workers = 16
) 
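One way the proposed interface could resolve names identically on both sides is sketched below. This is not the new data module's code: the variable sets and level list are hypothetical subsets, and expanding a bare pressure-level name to all of its levels is an assumed convenience that the proposal does not state.

```python
# Hypothetical subsets; not the library's real lists.
DEFAULT_PRESSURE_LEVELS = [50, 500, 850, 1000]
PRESSURE_LEVEL_VARS = {"geopotential", "temperature"}
SINGLE_LEVEL_VARS = {"2m_temperature"}


def resolve(names):
    """Turn user-facing names into concrete keys, using the same rule for in_vars and out_vars."""
    keys = []
    for name in names:
        if name in SINGLE_LEVEL_VARS:
            keys.append(name)
        elif name in PRESSURE_LEVEL_VARS:
            # Assumption: a bare 3D variable means "all levels".
            keys.extend(f"{name}_{lvl}" for lvl in DEFAULT_PRESSURE_LEVELS)
        else:
            var, sep, level = name.rpartition("_")
            if sep and var in PRESSURE_LEVEL_VARS and level.isdigit() \
                    and int(level) in DEFAULT_PRESSURE_LEVELS:
                keys.append(name)
            else:
                raise ValueError(f"unknown variable or level: {name!r}")
    return keys


in_keys = resolve(["geopotential_500"])
out_keys = resolve(["geopotential_500"])
print(in_keys == out_keys)  # True
```

Validating both argument lists with one function is what removes the asymmetry: whatever is legal as an input name is also legal as an output name.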

Does this sound good? @se0ngbin @jasonjewik @Hritikbansal