aditya-grover / climate-learn

Source code for ClimateLearn
MIT License

DataModule error when using 1-pressure level variable #55

Closed: noeliaof closed this issue 1 year ago

noeliaof commented 1 year ago

Describe the bug

I tried to use DataModule with geopotential at 500 hPa and with temperature at 850 hPa, the same way I use it for surface variables such as 2m temperature or total precipitation. However, when I use DataModule with a single variable at a single pressure level (e.g., geopotential_500), it raises an error while loading the data (in load_from_nc).

To Reproduce

Steps to reproduce the behavior:

from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

DATADIR = "/path/to/weatherbench"  # placeholder: replace with your local data directory

data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = DATADIR,
    in_vars = ["geopotential"],
    out_vars = ["geopotential"],
    train_start_year = Year(2015),
    val_start_year = Year(2016),
    test_start_year = Year(2017),
    end_year = Year(2018),
    pred_range = Days(3),
    subsample = Hours(1),
    batch_size = 128,
    num_workers = 1
)

The error I got:

      ----> 5 data_module = DataModule(
            6     dataset = "ERA5",
            7     task = "forecasting",
            8     root_dir = DATADIR,
            9     in_vars = ["geopotential"],
           10     out_vars = ["geopotential"],
           11     train_start_year = Year(2015),
           12     val_start_year = Year(2016),
           13     test_start_year = Year(2017),
           14     end_year = Year(2018),
           15     pred_range = Days(3),
           16     subsample = Hours(1),
           17     batch_size = 128,
           18     num_workers = 1
           19 )

      File ~/.conda/envs/pyTT/lib/python3.10/site-packages/climate_learn/data/module.py:112, in DataModule.__init__(self, dataset, task, root_dir, in_vars, out_vars, train_start_year, val_start_year, test_start_year, end_year, root_highres_dir, history, window, pred_range, subsample, batch_size, num_workers, pin_memory)
          109 caller = eval(f"{dataset.upper()}{task_string}")
          111 train_years = range(train_start_year, val_start_year)
      --> 112 self.train_dataset = caller(
          113     root_dir,
          114     root_highres_dir,
          115     in_vars,
          116     out_vars,
          117     history,
          118     window,
          119     pred_range.hours(),
          120     train_years,
          121     subsample.hours(),
          122     "train",
          123 )
          125 val_years = range(val_start_year, test_start_year)
          126 self.val_dataset = caller(
          127     root_dir,
          128     root_highres_dir,
         (...)
          136     "val",
          137 )

      File ~/.conda/envs/pyTT/lib/python3.10/site-packages/climate_learn/data/modules/era5_module.py:113, in ERA5Forecasting.__init__(self, root_dir, root_highres_dir, in_vars, out_vars, history, window, pred_range, years, subsample, split)
           99 def __init__(
          100     self,
          101     root_dir,
         (...)
          110     split="train",
          111 ):
          112     print(f"Creating {split} dataset")
      --> 113     super().__init__(root_dir, root_highres_dir, in_vars, years, split)
          115     self.in_vars = list(self.data_dict.keys())
          116     self.out_vars = out_vars

      File ~/.conda/envs/pyTT/lib/python3.10/site-packages/climate_learn/data/modules/era5_module.py:28, in ERA5.__init__(self, root_dir, root_highres_dir, variables, years, split)
           25 self.years = years
           26 self.split = split
      ---> 28 self.data_dict = self.load_from_nc(self.root_dir)
           29 if self.root_highres_dir is not None:
           30     self.data_highres_dict = self.load_from_nc(self.root_highres_dir)

      File ~/.conda/envs/pyTT/lib/python3.10/site-packages/climate_learn/data/modules/era5_module.py:69, in ERA5.load_from_nc(self, data_dir)
           67 if len(xr_data.shape) == 3:  # 8760, 32, 64
           68     xr_data = xr_data.expand_dims(dim="level", axis=1)
      ---> 69     data_dict[var].append(xr_data)
           70 else:  # pressure level
           71     for level in DEFAULT_PRESSURE_LEVELS:

      KeyError: 'geopotential'

I checked era5_module.py in more detail, and it seems (to me) that there might be a bug in the for loop starting at line 60:

    if len(xr_data.shape) == 3:  # 8760, 32, 64
        xr_data = xr_data.expand_dims(dim="level", axis=1)
        data_dict[var].append(xr_data)

In this case, no level is taken into account when a variable is available at only one pressure level.

Or did I miss something about how DataModule is meant to be used? Thanks!
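
To make the failure concrete, here is a simplified, hypothetical model of the branch quoted above. It assumes that data_dict is pre-populated with one "<var>_<level>" key per pressure level for variables in PRESSURE_LEVEL_VARS and one key per surface variable; the variable and level lists are illustrative, not the library's actual values, and arrays are replaced by their shapes.

```python
# Hypothetical, simplified model of the branch in ERA5.load_from_nc.
# Assumption: data_dict gets one "<var>_<level>" key per pressure level for
# variables in PRESSURE_LEVEL_VARS, and one plain key per surface variable.
PRESSURE_LEVEL_VARS = {"geopotential", "temperature"}  # illustrative subset
DEFAULT_PRESSURE_LEVELS = [500, 850]  # illustrative, not the library's list


def init_data_dict(variables):
    data_dict = {}
    for var in variables:
        if var in PRESSURE_LEVEL_VARS:
            for level in DEFAULT_PRESSURE_LEVELS:
                data_dict[f"{var}_{level}"] = []
        else:
            data_dict[var] = []
    return data_dict


def load_from_shapes(shapes):
    """shapes maps each variable name to the shape of its loaded array."""
    data_dict = init_data_dict(shapes)
    for var, shape in shapes.items():
        if len(shape) == 3:  # (time, lat, lon): handled like a surface variable
            data_dict[var].append(shape)
        else:  # (time, level, lat, lon): one entry per pressure level
            for level in DEFAULT_PRESSURE_LEVELS:
                data_dict[f"{var}_{level}"].append(shape)
    return data_dict


print(sorted(load_from_shapes({"2m_temperature": (8760, 32, 64)})))
# ['2m_temperature']
try:
    # only geopotential_500 on disk, so the array is 3-D
    load_from_shapes({"geopotential": (8760, 32, 64)})
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'geopotential'
```

A pressure-level variable stored with a single level hits the surface branch, which looks up the bare name, while only per-level keys were created for it.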

tung-nd commented 1 year ago

Hi Noelia. Can you try setting out_vars = ["geopotential_500"] (but leaving in_vars the same)? When preprocessing the data, we distinguish different pressure levels of the same variable by appending the level to the variable name (e.g., geopotential_500).

To avoid confusion, we'll change this so that users can specify pressure levels for both input and output variables. But you can use this hotfix for now.
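
The naming convention described above can be sketched as follows. level_key and expand_vars are hypothetical helpers, not ClimateLearn API; the sketch assumes the levels present on disk are known for each pressure-level variable. Under this convention, a target such as out_vars = ["geopotential_500"] has to match one of the expanded keys.

```python
# Sketch of the "<variable>_<level>" naming convention used in preprocessing.
# level_key and expand_vars are hypothetical helpers, not ClimateLearn API.
def level_key(var: str, level: int) -> str:
    return f"{var}_{level}"


def expand_vars(variables, levels_by_var):
    """Expand bare variable names into per-level keys.

    levels_by_var maps a pressure-level variable to the levels available on
    disk; variables missing from it are treated as surface variables.
    """
    keys = []
    for var in variables:
        levels = levels_by_var.get(var)
        if levels:
            keys.extend(level_key(var, lvl) for lvl in levels)
        else:
            keys.append(var)
    return keys


print(expand_vars(["geopotential", "2m_temperature"], {"geopotential": [500, 850]}))
# ['geopotential_500', 'geopotential_850', '2m_temperature']
```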

noeliaof commented 1 year ago

Hi Tung. I tried the setting:

data_module = DataModule(
    dataset = "ERA5",
    task = "forecasting",
    root_dir = DATADIR,
    in_vars = ["geopotential"],
    out_vars = ["geopotential_500"],
    train_start_year = Year(2015),
    val_start_year = Year(2016),
    test_start_year = Year(2017),
    end_year = Year(2018),
    pred_range = Days(3),
    subsample = Hours(1),
    batch_size = 128,
    num_workers = 1
)

I also tried other settings since I wasn't sure (e.g., setting both in_vars and out_vars to ["geopotential_500"]), but I still get the same error:

     67 if len(xr_data.shape) == 3:  # 8760, 32, 64
     68     xr_data = xr_data.expand_dims(dim="level", axis=1)
---> 69     data_dict[var].append(xr_data)

In my data folder (also downloaded from WeatherBench), geopotential has only one level, geopotential_500 (and similarly temperature has only temperature_850). I think the problem is that the code never reaches the else branch (line 70 in era5_module.py), so it fails because the level is never added to the key (as it is in the else branch, which uses data_dict[f"{var}_{level}"]). I might be wrong, but you may want to check just in case.

Thanks again!
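
As the traceback shows, the branch in load_from_nc looks only at the number of array dimensions, so a directory such as geopotential_500, whose files hold a 3-D (time, lat, lon) array, goes down the surface branch. One way to keep the level is to also look at the name. The sketch below assumes directory names follow the "<variable>_<level>" pattern; data_dict_keys and the PRESSURE_LEVEL_VARS subset are hypothetical, not ClimateLearn code.

```python
import re

PRESSURE_LEVEL_VARS = {"geopotential", "temperature"}  # illustrative subset

# "<variable>_<digits>" at the end of the name, e.g. "geopotential_500"
_LEVEL_SUFFIX = re.compile(r"^(?P<var>.+)_(?P<level>\d+)$")


def data_dict_keys(dirname, ndim, levels=()):
    """Return the data_dict keys for one data directory (hypothetical helper).

    dirname: directory name, e.g. "2m_temperature" or "geopotential_500".
    ndim: number of dimensions of the loaded array.
    levels: pressure levels present in a 4-D (time, level, lat, lon) array.
    """
    if ndim == 4:  # multi-level file: one key per level
        return [f"{dirname}_{level}" for level in levels]
    match = _LEVEL_SUFFIX.match(dirname)
    if match and match.group("var") in PRESSURE_LEVEL_VARS:
        # 3-D array from a single-level directory: keep the level in the key
        return [f"{match.group('var')}_{int(match.group('level'))}"]
    return [dirname]  # genuine surface variable


print(data_dict_keys("geopotential_500", ndim=3))  # ['geopotential_500']
```

Checking the name only when the array is 3-D leaves the existing multi-level path unchanged.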

tung-nd commented 1 year ago

Oh yes, that was the problem. We assumed that anyone using geopotential or temperature would use the multi-level data provided by WeatherBench, not the geopotential_500 and temperature_850 directories. We'll document this to avoid future confusion. In the meantime, can you try downloading the multi-level geopotential and temperature data from WeatherBench and see whether that solves the problem?

noeliaof commented 1 year ago

OK, I was able to fix that by adding an additional check, but out_vars must then be out_vars = ["geopotential_500"]. However, if I understand the logic of DataModule correctly, there may be another problem: DataModule fails when in_vars and out_vars differ. I'm assuming out_vars holds the target variables, so in principle in_vars and out_vars can be different. The ERA5 class, however, only loads "variables" (i.e., in_vars), which works when in_vars and out_vars are the same but fails otherwise. Or am I wrong here? Are in_vars and out_vars supposed to be the same?

jasonjewik commented 1 year ago

@noeliaof, you are correct in your understanding of in_vars and out_vars. As for whether they should be the same... that was an assumption we made when first writing this code. In hindsight, probably not the best decision.

In any case, this problem has been noticed before and brought to our attention in issue #50. PR #51 proposes changing the code so that out_vars would not need to be a subset of in_vars. It is currently in review.
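
The difference between loading only in_vars and loading the union of in_vars and out_vars can be sketched with a stand-in loader. fake_load and build_inputs_and_targets are hypothetical names, not the code in PR #51; the sketch de-duplicates with dict.fromkeys rather than a set so that the loading order stays deterministic.

```python
def fake_load(variables):
    """Stand-in for ERA5.load_from_nc: one placeholder entry per variable."""
    return {var: f"<array for {var}>" for var in variables}


def build_inputs_and_targets(in_vars, out_vars, load_union=True):
    if load_union:
        # dict.fromkeys de-duplicates while keeping first-seen order
        to_load = list(dict.fromkeys([*in_vars, *out_vars]))
    else:
        to_load = list(in_vars)  # old assumption: out_vars is a subset of in_vars
    data_dict = fake_load(to_load)
    inputs = {var: data_dict[var] for var in in_vars}
    targets = {var: data_dict[var] for var in out_vars}  # KeyError if never loaded
    return inputs, targets


inputs, targets = build_inputs_and_targets(["2m_temperature"], ["geopotential_500"])
print(list(targets))  # ['geopotential_500']
```

With load_union=False the same call raises KeyError: 'geopotential_500', because the target was never loaded.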

prakhar6sharma commented 1 year ago

@jasonjewik, should we keep this issue open now that #51 is merged?

jasonjewik commented 1 year ago

@prakhar6sharma thanks for reminding me. @noeliaof is the bug resolved on the latest commit?

noeliaof commented 1 year ago

Hi, I just checked this with the latest commit. It now works when in_vars and out_vars are different, but I'm afraid the error still happens for a variable with a single level (e.g., when only geopotential_500 is downloaded):

        KeyError                                  Traceback (most recent call last)
        Cell In[8], line 1
        ----> 1 data_module = DataModule(
              2     dataset = "ERA5",
              3     task = "forecasting",
              4     root_dir = DATADIR,
              5     in_vars = ["2m_temperature"],
              6     out_vars = ["geopotential"],
              7     train_start_year = Year(2015),
              8     val_start_year = Year(2016),
              9     test_start_year = Year(2017),
             10     end_year = Year(2018),
             11     pred_range = Days(3),
             12     subsample = Hours(1),
             13     batch_size = 128,
             14     num_workers = 1
             15 )

        File ~/climate-learn/src/climate_learn/data/module.py:112, in DataModule.__init__(self, dataset, task, root_dir, in_vars, out_vars, train_start_year, val_start_year, test_start_year, end_year, root_highres_dir, history, window, pred_range, subsample, batch_size, num_workers, pin_memory)
            109 caller = eval(f"{dataset.upper()}{task_string}")
            111 train_years = range(train_start_year, val_start_year)
        --> 112 self.train_dataset = caller(
            113     root_dir,
            114     root_highres_dir,
            115     in_vars,
            116     out_vars,
            117     history,
            118     window,
            119     pred_range.hours(),
            120     train_years,
            121     subsample.hours(),
            122     "train",
            123 )
            125 val_years = range(val_start_year, test_start_year)
            126 self.val_dataset = caller(
            127     root_dir,
            128     root_highres_dir,
           (...)
            136     "val",
            137 )

        File ~/climate-learn/src/climate_learn/data/modules/era5_module.py:114, in ERA5Forecasting.__init__(self, root_dir, root_highres_dir, in_vars, out_vars, history, window, pred_range, years, subsample, split)
            112 print(f"Creating {split} dataset")
            113 unique_vars = list(set(in_vars) | set(out_vars))
        --> 114 super().__init__(root_dir, root_highres_dir, unique_vars, years, split)
            116 self.in_vars = list(self.data_dict.keys())
            117 self.out_vars = out_vars

        File ~/climate-learn/src/climate_learn/data/modules/era5_module.py:28, in ERA5.__init__(self, root_dir, root_highres_dir, variables, years, split)
             25 self.years = years
             26 self.split = split
        ---> 28 self.data_dict = self.load_from_nc(self.root_dir)
             29 if self.root_highres_dir is not None:
             30     self.data_highres_dict = self.load_from_nc(self.root_highres_dir)

        File ~/climate-learn/src/climate_learn/data/modules/era5_module.py:69, in ERA5.load_from_nc(self, data_dir)
             67 if len(xr_data.shape) == 3:  # 8760, 32, 64
             68     xr_data = xr_data.expand_dims(dim="level", axis=1)
        ---> 69     data_dict[var].append(xr_data)
             70 else:  # pressure level
             71     for level in DEFAULT_PRESSURE_LEVELS:

        KeyError: 'geopotential'

The problem comes from the way data_dict is built: load_from_nc should check whether the variable is in PRESSURE_LEVEL_VARS, and ERA5Forecasting should then check it again. I made a couple of changes myself to make it work, but the solution is not ideal.
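
The second check mentioned above, in ERA5Forecasting, could map the names passed in in_vars/out_vars onto keys that actually exist in data_dict, and fail with a descriptive message otherwise. A minimal sketch; resolve_vars and the PRESSURE_LEVEL_VARS subset are illustrative assumptions, not the refactoring that was eventually merged:

```python
PRESSURE_LEVEL_VARS = {"geopotential", "temperature"}  # illustrative subset


def resolve_vars(requested, available_keys):
    """Map requested names onto keys that actually exist in data_dict.

    Exact keys (e.g. "geopotential_500") pass through unchanged; a bare
    pressure-level name (e.g. "geopotential") expands to every loaded
    "<name>_<level>" key; anything else raises a descriptive KeyError.
    """
    resolved = []
    for name in requested:
        if name in available_keys:
            resolved.append(name)
            continue
        prefix = f"{name}_"
        per_level = [
            key for key in available_keys
            if key.startswith(prefix) and key[len(prefix):].isdigit()
        ]
        if name in PRESSURE_LEVEL_VARS and per_level:
            resolved.extend(per_level)
        else:
            raise KeyError(
                f"{name!r} not found; loaded keys are {sorted(available_keys)}"
            )
    return resolved


print(resolve_vars(["geopotential"], ["2m_temperature", "geopotential_500"]))
# ['geopotential_500']
```

Raising with the list of loaded keys makes a mismatch such as "geopotential" versus "geopotential_500" visible immediately, instead of surfacing as a bare KeyError deep inside load_from_nc.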

tung-nd commented 1 year ago

We are working on refactoring the data loading code, which will resolve this problem. We'll update you when it's done.

prakhar6sharma commented 1 year ago

@tung-nd, #68 is merged, but it still doesn't resolve the way data_dict is built. Can you please create a separate issue describing in more detail exactly how data_dict should be built?