mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Refactor model class hierarchy #49

Open joeloskarsson opened 3 weeks ago

joeloskarsson commented 3 weeks ago

Background

The different models that can be trained in Neural-LAM are currently all subclasses of pytorch_lightning.LightningModule. In particular, much of the functionality sits in the first subclass, ARModel. The current hierarchy looks like this:

[Figure: classes_current, diagram of the current class hierarchy]

(I am making these diagrams rather than fancier UML diagrams since I think they are enough for the level of detail we need to discuss here.)

The problem

In the current hierarchy everything is a subclass of ARModel. This has a number of drawbacks:

Proposed new hierarchy

I propose to split up the current class hierarchy into classes that have clear responsibilities. Rather than all inheriting from ARModel, these should be members of each other where suitable (composition instead of inheritance). A first idea for this is shown below, also including potential future classes for new models (to show how this is more extensible):

[Figure: classes_proposal, diagram of the proposed class hierarchy]
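As a minimal sketch of the composition idea (not the eventual Neural-LAM code), the StepPredictor/ARForecaster split could look like the following in plain PyTorch. The class names follow the proposal; the forward signatures, the `init_states`/`forcing` tensor layouts and the toy `LinearStepPredictor` are assumptions made for illustration only.

```python
import torch
from torch import nn


class StepPredictor(nn.Module):
    """Hypothetical base class: predicts X_{t+1} from X_t, X_{t-1} and forcing."""

    def forward(self, prev_state, prev_prev_state, forcing):
        raise NotImplementedError


class LinearStepPredictor(StepPredictor):
    """Toy stand-in for e.g. a graph-based predictor (hypothetical)."""

    def __init__(self, state_dim, forcing_dim):
        super().__init__()
        self.layer = nn.Linear(2 * state_dim + forcing_dim, state_dim)

    def forward(self, prev_state, prev_prev_state, forcing):
        # All inputs are (B, N_grid, d); concatenate along the feature dimension
        x = torch.cat((prev_state, prev_prev_state, forcing), dim=-1)
        # Predict a residual update to the latest state
        return prev_state + self.layer(x)


class ARForecaster(nn.Module):
    """Holds a StepPredictor as a member (composition) and unrolls it in time."""

    def __init__(self, step_predictor):
        super().__init__()
        self.step_predictor = step_predictor

    def forward(self, init_states, forcing):
        # init_states: (B, 2, N_grid, d_state) holding X_{t-1}, X_t
        # forcing: (B, T, N_grid, d_forcing), one slice per predicted step
        prev_prev_state, prev_state = init_states[:, 0], init_states[:, 1]
        predictions = []
        for t in range(forcing.shape[1]):
            next_state = self.step_predictor(prev_state, prev_prev_state, forcing[:, t])
            predictions.append(next_state)
            prev_prev_state, prev_state = prev_state, next_state
        return torch.stack(predictions, dim=1)  # (B, T, N_grid, d_state)


forecaster = ARForecaster(LinearStepPredictor(state_dim=3, forcing_dim=2))
forecast = forecaster(torch.randn(4, 2, 10, 3), torch.randn(4, 5, 10, 2))
print(forecast.shape)  # torch.Size([4, 5, 10, 3])
```

Swapping in a different predictor (e.g. a CNN) would only change the object passed to `ARForecaster`; the unrolling loop itself does not need to be subclassed.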

The core components are (here I count static features as part of forcing):

In the figure above we can also see how new kinds of models could fit into this hierarchy:

This is meant as a starting point for discussion, and there will likely be things I have not thought about. Some parts of this will have to be hammered out when actually writing these classes, but I'd rather discuss whether this is a good direction to take before starting to do too much work. Tagging @leifdenby and @sadamov for visibility.

sadamov commented 3 weeks ago

@joeloskarsson This looks great! It introduces even more freedom and clearly structured, understandable classes. I am no expert in this, so I won't add much here, but I really think that relieving ar_model.py of some of its duties will make our lives easier. Also, you already have experience with probabilistic Neural-LAM, so I trust you to make the changes necessary to implement that in the future. ❤️

leifdenby commented 2 weeks ago

I have some time to give this a proper think through now @joeloskarsson so here come my thoughts :)

> • All functionality ends up in one class. There is no clear division of responsibility, but rather we end up with one class that does everything from logging, device-handling, unrolling forecasts and the actual forward and backward pass through GNNs.
>
> • The subclasses do not utilize the standard torch forward calls, but rather must resort to our own similar construction (e.g. predict_step).
>
> • This limits us to deterministic, auto-regressive models.
>
> • This is hard to extend for ensemble models.
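To make the second point concrete: `nn.Module.__call__` is what dispatches to `forward()`, and it also runs any registered module hooks. A custom method such as `predict_step` bypasses that machinery. The two toy classes below are illustrative only, not Neural-LAM code.

```python
import torch
from torch import nn


class CustomStepModel(nn.Module):
    """Old pattern (toy): the prediction lives in a custom method."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

    def predict_step(self, x):
        return self.layer(x)


class ForwardStepModel(nn.Module):
    """Proposed pattern (toy): the prediction lives in forward()."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

    def forward(self, x):
        return self.layer(x)


calls = []


def record_call(module, inputs, output):
    # Forward hooks receive (module, inputs, output) after forward() has run
    calls.append(type(module).__name__)


x = torch.randn(2, 3)
custom, standard = CustomStepModel(), ForwardStepModel()
custom.register_forward_hook(record_call)
standard.register_forward_hook(record_call)

custom.predict_step(x)  # bypasses nn.Module.__call__, so the hook never fires
standard(x)             # goes through __call__ -> forward(), so the hook fires
print(calls)  # ['ForwardStepModel']
```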

Yes, agreed! Thank you for making that diagram too. It is very helpful in understanding the current structure.

> ForecasterModule: Takes over much of the responsibility of the old ARModel. Handles things not directly related to the neural network components, such as plotting, logging and moving batches to the right device. This inherits pytorch_lightning.LightningModule and has the different train/val/test steps. In each step (train/val/test), it unpacks the batch of tensors and uses a Forecaster to produce a full forecast. It is also responsible for computing the loss based on a produced forecast (this could also be in Forecaster, not entirely sure about this).
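A minimal sketch of how the ForecasterModule described in the quote could delegate to a Forecaster, assuming a batch of `(init_states, target_states, forcing)` tensors and an MSE loss. In Neural-LAM this class would subclass `pytorch_lightning.LightningModule` (whose per-batch training hook is `training_step(batch, batch_idx)`); here it subclasses `torch.nn.Module` to stay self-contained, and `PersistenceForecaster` is a hypothetical toy.

```python
import torch
import torch.nn.functional as F
from torch import nn


class PersistenceForecaster(nn.Module):
    """Toy Forecaster: repeats the latest initial state at every lead time."""

    def forward(self, init_states, forcing):
        n_steps = forcing.shape[1]
        # init_states[:, -1:] is (B, 1, N_grid, d_state) -> (B, T, N_grid, d_state)
        return init_states[:, -1:].expand(-1, n_steps, -1, -1)


class ForecasterModule(nn.Module):
    """Sketch of the outer module. In Neural-LAM this would subclass
    pytorch_lightning.LightningModule; nn.Module keeps the sketch self-contained."""

    def __init__(self, forecaster):
        super().__init__()
        self.forecaster = forecaster

    def training_step(self, batch, batch_idx):
        # Hypothetical batch layout: (init_states, target_states, forcing)
        init_states, target_states, forcing = batch
        prediction = self.forecaster(init_states, forcing)
        # Loss computed in the module, following the description above
        return F.mse_loss(prediction, target_states)


init_states = torch.randn(2, 2, 5, 3)
forcing = torch.zeros(2, 4, 5, 1)
target_states = init_states[:, -1:].repeat(1, 4, 1, 1)
module = ForecasterModule(PersistenceForecaster())
loss = module.training_step((init_states, target_states, forcing), batch_idx=0)
print(loss.item())  # 0.0
```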

This sounds good. My only suggestion would be to not include the plotting as a method of the ForecasterModule or Forecaster classes, but instead keep this a separate (and therefore stateless) function. Keeping plotting separate would allow us to make the model architecture classes leaner and to design the plotting functions for reuse. Regarding where to store the loss operation: I think I would place this with Forecaster since that class only handles forecasting for a given time-window, whereas ForecasterModule would also handle the data-handling used for constructing time-windows.

I think the structure you have proposed with ARForecaster, StepPredictor and the current (and future) graph-based models makes a lot of sense too; I really like it. And I can see how new architectures could be incorporated, though not completely. I was thinking for example for the BaseGraphModel family of methods vs a future CNNPredictor, these will have different requirements in terms of what a single batch should contain (spatial coordinates stacked or not). Have you thought about at which level this will impact the forward methods of this new hierarchy? Maybe this is self-evident, but to me it is quite tricky. It might simply be a question of well-written docstrings and a convention (in terms of which dimensions in the input tensor represent what) that we apply through the code. Minor thought: should it maybe be BaseGraphModelPredictor rather than BaseGraphModel or something like that? Just to follow the *Predictor name convention you are setting up here.

Maybe when you/we start coding this we could start by including the additions you foresee making (ensemble prediction classes, CNN, ViT) as simply empty placeholder classes? Then we can be sure the data-structures we pass around will be general enough to apply in future too?

In general this is really fantastic though.

joeloskarsson commented 2 weeks ago

Thanks for some very valuable input @leifdenby!

> My only suggestion would be to not include the plotting as a method of the ForecasterModule or Forecaster classes, but instead keep this a separate (and therefore stateless) function.

Yes, that is what I have in mind as well, in the same way as it is handled today, where ARModel just makes calls to plotting functions in neural_lam.vis. Potentially even more of the surrounding logic could be moved out of the ForecasterModule, but we will see whether this would be the right PR for that.

> Regarding where to store the loss operation: I think I would place this with Forecaster since that class only handles forecasting for a given time-window, whereas ForecasterModule would also handle the data-handling used for constructing time-windows.

There is a logic to that, yes. I am still thinking about how to do this in practice. I am somewhat reluctant to send the target (ground truth) tensor further down the hierarchy than ForecasterModule, which would be required to compute the loss in Forecaster. In practice one would either have to 1) let the forward function of Forecaster map from inputs + target tensor to the loss value, which violates what one would expect a forward call to do (produce the output/prediction), or 2) have a separate function in Forecaster for computing the loss, which means we again have to define our own interface for parts of this, rather than relying completely on Lightning's training_step. It's a bit tricky; maybe it is easiest to start hacking away at it and see what becomes most convenient.
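Option 2 could look roughly like this: `forward()` keeps its usual meaning (inputs to prediction), and a separate method is the only place the target enters the Forecaster. The method names `compute_loss` and `loss_fn`, the MSE default and the toy `ZeroForecaster` are all hypothetical.

```python
import torch
from torch import nn


class Forecaster(nn.Module):
    """Hypothetical base class: forward() predicts, compute_loss() scores."""

    def forward(self, init_states, forcing):
        raise NotImplementedError

    def loss_fn(self, prediction, target_states):
        # Deterministic default; a probabilistic forecaster could override this
        return torch.mean((prediction - target_states) ** 2)

    def compute_loss(self, init_states, forcing, target_states):
        # The target only reaches the Forecaster through this method
        prediction = self(init_states, forcing)
        return self.loss_fn(prediction, target_states)


class ZeroForecaster(Forecaster):
    """Toy forecaster predicting all zeros, shape (B, T, N_grid, d_state)."""

    def forward(self, init_states, forcing):
        batch_size, n_steps = forcing.shape[:2]
        return init_states.new_zeros((batch_size, n_steps, *init_states.shape[2:]))


init_states = torch.randn(2, 2, 5, 3)
forcing = torch.zeros(2, 4, 5, 1)
target_states = torch.ones(2, 4, 5, 3)
loss = ZeroForecaster().compute_loss(init_states, forcing, target_states)
print(loss.item())  # 1.0
```

The ForecasterModule's training step would then reduce to unpacking the batch and calling `compute_loss`, at the cost of one method outside the standard Lightning interface, which is the trade-off described above.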

> I was thinking for example for the BaseGraphModel family of methods vs a future CNNPredictor, these will have different requirements in terms of what a single batch should contain (spatial coordinates stacked or not). Have you thought about at which level this will impact the forward methods of this new hierarchy? Maybe this is self-evident, but to me it is quite tricky. It might simply be a question of well-written docstrings and a convention (in terms of which dimensions in the input tensor represent what) that we apply through the code.

Yes, this is a good point, and I have given it some thought. Note that having one (flattened) spatial dimension or two (x, y) should not majorly impact anything before the StepPredictors. Since moving between these is just a cheap reshape, it is mostly a matter of convention. You can either flatten already in the dataset class and then reshape to two spatial dimensions in, e.g., CNN models (the current setup), or keep the spatial dimensions from the dataloader and flatten in the graph-based models. This should not be very hard to change later on if needed.
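The reshape in question can be written as two small helpers. The function names are hypothetical, and the sketch assumes the flattened grid dimension is ordered row-major over (N_x, N_y).

```python
import torch


def flat_to_grid(x, grid_shape):
    """(B, N_grid, d) -> (B, N_x, N_y, d); requires N_grid == N_x * N_y."""
    return x.reshape(x.shape[0], *grid_shape, x.shape[-1])


def grid_to_flat(x):
    """(B, N_x, N_y, d) -> (B, N_x * N_y, d)"""
    return x.flatten(1, 2)


x_flat = torch.randn(2, 24, 3)          # B=2, N_grid=24, d=3
x_grid = flat_to_grid(x_flat, (4, 6))   # (2, 4, 6, 3)
print(x_grid.shape)                     # torch.Size([2, 4, 6, 3])

# The round trip recovers the input, and for a contiguous tensor
# the reshape is a view, so no data is copied
assert torch.equal(grid_to_flat(x_grid), x_flat)
assert x_grid.data_ptr() == x_flat.data_ptr()
```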

Here is an example of a dummy CNN class I have lying around that handles this with reshapes (using the old class hierarchy):

```python
import torch

from neural_lam.models.ar_model import ARModel
from neural_lam import constants

class NewModel(ARModel):
    """
    A new auto-regressive weather forecasting model
    """
    def __init__(self, args):
        super().__init__(args)

        # Some dimensionalities that can be useful to have stored
        self.input_dim = 2*constants.grid_state_dim + constants.grid_forcing_dim +\
            constants.batch_static_feature_dim
        self.output_dim = constants.grid_state_dim

        # TODO: Define modules as members here that will be used in predict_step
        self.layer = torch.nn.Conv2d(self.input_dim, self.output_dim, 1) # Dummy layer

    def predict_step(self, prev_state, prev_prev_state, batch_static_features, forcing):
        """
        Predict weather state one time step ahead
        X_{t-1}, X_t -> X_{t+1}

        prev_state: (B, N_grid, d_state), weather state X_t at time t
        prev_prev_state: (B, N_grid, d_state), weather state X_{t-1} at time t-1
        batch_static_features: (B, N_grid, batch_static_feature_dim), static forcing
        forcing: (B, N_grid, forcing_dim), dynamic forcing

        Returns:
        next_state: (B, N_grid, d_state), predicted weather state X_{t+1} at time t+1
        pred_std: None or (B, N_grid, d_state), predicted standard-deviations
                    (pred_std can be ignored by just returning None)
        """

        # Concatenate all inputs, then reshape 1d grid to 2d image
        input_flat = torch.cat((prev_state, prev_prev_state, batch_static_features,
            forcing), dim=-1) # (B, N_grid, d_input)
        input_grid = torch.reshape(input_flat, (-1, *constants.grid_shape,
            input_flat.shape[2])) # (B, N_x, N_y, d_input)
        # Most computer vision methods in torch want channel dimension first
        input_grid = input_grid.permute((0,3,1,2)).contiguous() # (B, d_input, N_x, N_y)

        # TODO: Feed input_grid through some model to predict output_grid
        output_grid = self.layer(input_grid) # Shape (B, d_state, N_x, N_y)

        # Reshape back from 2d to flattened grid dimension
        output_grid = output_grid.permute((0,2,3,1)) # (B, N_x, N_y, d_state)
        next_state = output_grid.flatten(1,2) # (B, N_grid, d_state)

        return next_state, None
```

A good reason to keep it as is, flattening in the Dataset class, is that once we start moving to more refined boundary setups with different gridding, there will no longer be a 2D grid representation of the input data. Grid cells will lie on irregular grids, meaning that it is not trivial to apply CNN models to them. Keeping things flattened all the way to the CNN model means that if you want to define a CNN on this data, you have to decide how to handle this irregularity, but we don't have to make such decisions before the forward call in the CNN class.
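One possible way a CNN predictor could handle this irregularity inside its own forward call is to scatter the flattened nodes onto a regular canvas, apply the convolution, and gather the values back at the nodes. This is only a sketch: the precomputed `cell_index` mapping (assumed to assign each node a unique canvas cell) and zero-filling of empty cells are assumptions, not a decision made in this thread.

```python
import torch


def scatter_to_canvas(x_flat, cell_index, grid_shape):
    """(B, N_nodes, d) irregular nodes -> (B, d, N_x, N_y) regular canvas.

    cell_index: (N_nodes,) long tensor giving each node's flat canvas cell
    (hypothetical precomputed assignment, assumed unique per node).
    Cells that receive no node stay zero.
    """
    batch_size, _, d = x_flat.shape
    n_cells = grid_shape[0] * grid_shape[1]
    canvas = x_flat.new_zeros((batch_size, n_cells, d))
    canvas[:, cell_index] = x_flat
    # Channels-first layout, as expected by torch.nn.Conv2d
    return canvas.reshape(batch_size, *grid_shape, d).permute(0, 3, 1, 2)


def gather_from_canvas(canvas, cell_index):
    """(B, d, N_x, N_y) canvas -> (B, N_nodes, d) values at the node cells."""
    batch_size, d = canvas.shape[:2]
    flat = canvas.permute(0, 2, 3, 1).reshape(batch_size, -1, d)
    return flat[:, cell_index]


x = torch.randn(2, 5, 3)                      # 5 irregular nodes, d=3
cell_index = torch.tensor([0, 3, 7, 8, 11])   # their cells on a 3 x 4 canvas
canvas = scatter_to_canvas(x, cell_index, (3, 4))
conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in CNN
out = gather_from_canvas(conv(canvas), cell_index)
print(canvas.shape, out.shape)  # torch.Size([2, 3, 3, 4]) torch.Size([2, 5, 3])
```

The model outside the CNN class still only sees `(B, N_nodes, d)` tensors; how to place irregular nodes on the canvas stays a choice local to the CNN predictor.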

> Minor thought: should it maybe be BaseGraphModelPredictor rather than BaseGraphModel or something like that? Just to follow the *Predictor name convention you are setting up here.

Yes, that is a good idea. Or even BaseGraphPredictor and BaseHiGraphPredictor. Not sure if one would want to also name the final models (e.g. HiLAM) something like HiLAMPredictor, just to specify that these are also StepPredictors? On the other hand, it's maybe best to keep to the exact naming of a model used in e.g. a paper.

> Maybe when you/we start coding this we could start by including the additions you foresee making (ensemble prediction classes, CNN, ViT) as simply empty placeholder classes? Then we can be sure the data-structures we pass around will be general enough to apply in future too?

This could be smart, yes. For the computer vision models (CNN, ViT) I don't think there will be much to keep in mind (see discussion above), but for the ensemble model this could make a lot of sense, especially since I will populate it with the code from the probabilistic model branches later anyhow. My plan is to do this refactoring first, to make merging that code easier and nicer.
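An ensemble placeholder of this kind could be as small as the sketch below. The position of the member dimension, the `num_members` argument and the zero output are all assumptions; the point is only that downstream code (metrics, plotting) can be written against the output shape before the real probabilistic model exists.

```python
import torch
from torch import nn


class Forecaster(nn.Module):
    """Hypothetical shared interface: (init_states, forcing) -> forecast."""

    def forward(self, init_states, forcing):
        raise NotImplementedError


class EnsembleForecaster(Forecaster):
    """Placeholder that pins down the ensemble output layout.

    Assumed layout: (B, S, T, N_grid, d_state) with S ensemble members,
    one extra dimension compared to a deterministic (B, T, N_grid, d_state).
    """

    def __init__(self, num_members):
        super().__init__()
        self.num_members = num_members

    def forward(self, init_states, forcing):
        batch_size, n_steps = forcing.shape[:2]
        n_grid, d_state = init_states.shape[2:]
        # TODO: replace with the probabilistic model; zeros only fix the shape
        return init_states.new_zeros(
            (batch_size, self.num_members, n_steps, n_grid, d_state)
        )


ensemble = EnsembleForecaster(num_members=8)
forecast = ensemble(torch.randn(2, 2, 5, 3), torch.randn(2, 4, 5, 1))
ens_mean = forecast.mean(dim=1)  # back to the deterministic layout
print(forecast.shape, ens_mean.shape)
```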

Overall it looks like there is support for this idea, so I can start writing the code for it. Then we can discuss more details in the upcoming PR. I am happy to do the work on this, but I find it hard to give a timeline, as it is not directly crucial to progress in ongoing research projects. Anyhow, I think this could potentially fit in v0.3.0 in the roadmap?