mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Remove batch-static tensor from dataset class and models #13

Closed: joeloskarsson closed this 3 months ago

joeloskarsson commented 3 months ago

The batch-static tensor contained forcing that differed between initialization times but stayed constant over all lead times of a forecast. For the MEPS data we used it for the land-water mask, since the mask can change throughout the year, but we could not produce separate values per lead time (as we do for all other forcing).

This PR removes the batch-static features as an explicit extra input. The motivation is:

  1. Such input features are rare and highly specific to a particular dataset.
  2. If such inputs exist, it is better to treat them like any other forcing. The values then have to be repeated over the temporal dimension, but this can be handled either in pre-processing or easily in the Dataset class. In this PR the MEPS Dataset class is changed to take this approach.
  3. Passing the batch-static features around clutters the code. Most datasets do not use them, which requires constant special-case checks for whether they are None.
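A minimal sketch of point 2, assuming PyTorch tensors where forcing has shape (pred_steps, num_grid_nodes, d_forcing) and the batch-static features have shape (num_grid_nodes, d_static). The function name `bake_static_into_forcing` and these shapes are illustrative assumptions, not the actual neural-lam code. It repeats the static values over the lead-time axis and appends them as extra forcing features:

```python
import torch


def bake_static_into_forcing(
    forcing: torch.Tensor, batch_static: torch.Tensor
) -> torch.Tensor:
    """Append batch-static features to per-lead-time forcing (hypothetical helper).

    forcing:      (pred_steps, num_grid_nodes, d_forcing)
    batch_static: (num_grid_nodes, d_static), constant over lead times
    returns:      (pred_steps, num_grid_nodes, d_forcing + d_static)
    """
    pred_steps = forcing.shape[0]
    # Add a leading time axis and broadcast it to pred_steps (a view, no copy yet)
    static_over_time = batch_static.unsqueeze(0).expand(pred_steps, -1, -1)
    # torch.cat materializes the result: every lead time gets the same static values
    return torch.cat((forcing, static_over_time), dim=-1)


# Example: 4 lead times, 10 grid nodes, 3 forcing features, 1 land-water mask feature
forcing = torch.randn(4, 10, 3)
land_water_mask = torch.randint(0, 2, (10, 1)).float()
combined = bake_static_into_forcing(forcing, land_water_mask)
print(combined.shape)  # torch.Size([4, 10, 4])
```

Using `expand` instead of `repeat` avoids an intermediate copy; the single copy happens in `torch.cat`.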

This PR changes:

  1. Bake the batch-static features into the normal forcing in the MEPS Dataset class.
  2. Change the Dataset class to return only 3 tensors per sample (init, target, forcing).
  3. Stop extracting the batch-static tensor from the batch and passing it around in the graph-based models, while making sure the input dimensions still line up so that older checkpoints can be loaded correctly.

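A minimal sketch of changes 1 to 3, using a hypothetical `ToyForecastDataset` in place of the real MEPS Dataset. The shapes, feature counts, and the two initial states are assumptions for illustration. Each sample is (init_states, target_states, forcing), with the batch-static feature already baked into forcing, so a batch unpacks into exactly three tensors. Because the baked forcing has width d_forcing + d_static, an input layer sized for the old separate inputs keeps its shape, assuming the concatenation order is also preserved:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyForecastDataset(Dataset):
    """Hypothetical dataset returning (init_states, target_states, forcing)."""

    def __init__(self, num_samples=8, pred_steps=4, num_nodes=10,
                 d_state=5, d_forcing=3, d_static=1):
        self.num_samples = num_samples
        self.pred_steps = pred_steps
        self.num_nodes = num_nodes
        self.d_state = d_state
        self.d_forcing = d_forcing
        self.d_static = d_static

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Two initial states (assumed: previous and current time step)
        init_states = torch.randn(2, self.num_nodes, self.d_state)
        target_states = torch.randn(self.pred_steps, self.num_nodes, self.d_state)
        forcing = torch.randn(self.pred_steps, self.num_nodes, self.d_forcing)
        # Batch-static feature: varies per sample, constant over lead times
        static = torch.rand(self.num_nodes, self.d_static)
        static = static.unsqueeze(0).expand(self.pred_steps, -1, -1)
        forcing = torch.cat((forcing, static), dim=-1)
        return init_states, target_states, forcing


loader = DataLoader(ToyForecastDataset(), batch_size=2)
# Models now unpack three tensors instead of four
init_states, target_states, forcing = next(iter(loader))
print(init_states.shape, target_states.shape, forcing.shape)
# torch.Size([2, 2, 10, 5]) torch.Size([2, 4, 10, 5]) torch.Size([2, 4, 10, 4])

# Old width (d_forcing + d_static fed separately) equals the new forcing width
old_input_width = 3 + 1
assert forcing.shape[-1] == old_input_width
```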
joeloskarsson commented 3 months ago

@sadamov Hope it's OK that I add you as a reviewer on PRs like this :) I think it's valuable to have a second pair of eyes on the changes, and it's also a good way for you to stay up to date on small things I am changing.

The changes to the MEPS Dataset class here are not very important; this PR is really motivated by moving away from code that is too specific to that data.

joeloskarsson commented 3 months ago

Thanks for taking a look!

I just realized I forgot to update create_parameter_weights.py, since the Dataset class is used there too. Will fix that (it should only be a tiny index change) and then merge.
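A minimal sketch of what such an index change can look like in a script that accumulates forcing statistics over a DataLoader. The `TensorDataset` stand-in, the shapes, and the assumed old layout (init, target, batch_static, forcing) are hypothetical; the actual contents of create_parameter_weights.py may differ:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the real Dataset: each sample is
# (init_states, target_states, forcing), with no batch-static tensor.
num_samples, pred_steps, num_nodes, d_forcing = 6, 4, 10, 4
dataset = TensorDataset(
    torch.randn(num_samples, 2, num_nodes, 5),                   # init_states
    torch.randn(num_samples, pred_steps, num_nodes, 5),          # target_states
    torch.randn(num_samples, pred_steps, num_nodes, d_forcing),  # forcing
)
loader = DataLoader(dataset, batch_size=3)

forcing_sum = torch.zeros(d_forcing)
count = 0
# Assumed old layout: (init, target, batch_static, forcing), forcing at index 3.
# New layout: (init, target, forcing), forcing at index 2.
for init_batch, target_batch, forcing_batch in loader:
    # Sum over batch, lead-time and grid-node axes, keeping one value per feature
    forcing_sum += forcing_batch.sum(dim=(0, 1, 2))
    count += forcing_batch.shape[0] * forcing_batch.shape[1] * forcing_batch.shape[2]

forcing_mean = forcing_sum / count
print(forcing_mean.shape)  # torch.Size([4])
```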