Open joeloskarsson opened 3 weeks ago
@joeloskarsson This looks great! Introducing even more freedom and clearly structured, understandable classes. I am no expert in this, so I am not adding too much here. I really do think, though, that relieving `ar_models.py` of some of its duties will make our lives easier. Also, you already have experience with probabilistic Neural-LAM, so I trust you to make the changes necessary to implement that in the future. :heart:
I have some time to give this a proper think through now @joeloskarsson so here come my thoughts :)
> - All functionality ends up in one class. There is no clear division of responsibility, but rather we end up with one class that does everything from logging, device-handling and unrolling forecasts to the actual forward and backward pass through GNNs.
> - The subclasses do not utilize the standard torch `forward` calls, but rather must resort to our own similar construction (e.g. `predict_step`).
> - This limits us to deterministic, auto-regressive models.
> - This is hard to extend upon for ensemble-models.
Yes, agreed! Thank you for making that diagram too. It is very helpful in understanding the current structure.
> `ForecasterModule`: Takes over much of the responsibility of the old `ARModel`. Handles things not directly related to the neural network components such as plotting, logging, moving batches to the right device. This inherits `pytorch_lightning.LightningModule` and has the different train/val/test steps. In each step (train/val/test), unpacks the batch of tensors and uses a `Forecaster` to produce a full forecast. Also responsible for computing the loss based on a produced forecast (could also be in `Forecaster`, not entirely sure about this).
This sounds good. My only suggestion would be to not include the plotting as a method of the `ForecasterModule` or `Forecaster` classes, but instead keep this a separate (and therefore stateless) function. Keeping plotting separate would allow us to make the model architecture classes more lean and design for reuse of plotting functions. Regarding where to store the loss operation: I think I would place this with `Forecaster`, since that class only handles forecasting for a given time-window, whereas `ForecasterModule` would also handle the data-handling used for constructing time-windows.
I think the structure you have proposed with `ARForecaster`, `StepPredictor` and the current (and future) graph-based models makes a lot of sense too, I really like it. And I see how new architectures can be incorporated, but not completely. I was thinking for example of the `BaseGraphModel` family of methods vs a future `CNNPredictor`: these will have different requirements in terms of what a single batch should contain (spatial coordinates stacked or not). Have you thought about at which level this will impact the `forward` methods of this new hierarchy? Maybe this is self-evident, but to me it is quite tricky. It might simply be a question of well-written docstrings and a convention (in terms of which dimensions in the input tensor represent what) that we apply throughout the code. Minor thought: should it maybe be `BaseGraphModelPredictor` rather than `BaseGraphModel`, or something like that? Just to follow the `*Predictor` naming convention you are setting up here.
Maybe when you/we start coding this we could start by including the additions you foresee making (ensemble prediction classes, CNN, ViT) as simply empty placeholder classes? Then we can be sure the data-structures we pass around will be general enough to apply in the future too?
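For concreteness, such empty placeholders could be as simple as the sketch below. All class names and argument lists here are just assumptions based on the hierarchy sketched in the proposal, not a definitive interface:

```python
import torch


class Forecaster(torch.nn.Module):
    """Maps initial states + forcing to a full forecast."""

    def forward(self, init_states, forcing, num_steps):
        raise NotImplementedError


class EnsembleForecaster(Forecaster):
    """Placeholder: would produce an ensemble of forecasts."""


class StepPredictor(torch.nn.Module):
    """Maps the two previous states + forcing to the next state."""

    def forward(self, prev_state, prev_prev_state, forcing):
        raise NotImplementedError


class CNNPredictor(StepPredictor):
    """Placeholder: would internally reshape the flattened grid to 2d."""


class ViTPredictor(StepPredictor):
    """Placeholder for a vision-transformer step predictor."""
```

Even empty classes like these would force us to settle early on what the arguments to each `forward` should look like.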
In general this is really fantastic though.
Thanks for some very valuable input @leifdenby!
> My only suggestion would be to not include the plotting as a method of the `ForecasterModule` or `Forecaster` classes, but instead keep this a separate (and therefore stateless) function.
Yes, that is what I have in mind as well. In the same way it is handled today, where `ARModel` just makes calls to plotting functions in `neural_lam.vis`. Potentially even more surrounding logic could be moved away from the `ForecasterModule`, but we will see if this would be the correct PR for that.
> Regarding where to store the loss operation: I think I would place this with `Forecaster` since that class only handles forecasting for a given time-window, whereas `ForecasterModule` would also handle the data-handling used for constructing time-windows.
There is a logic to that, yes. I am still thinking about how to do this in practice. I am in a way not particularly happy to send the target (ground truth) tensor further down the hierarchy than `ForecasterModule`, which would be required to compute the loss in `Forecaster`. In practice one would have to either:

1. Let the `forward` function of `Forecaster` map from inputs + target tensor to the loss value. This violates what one would expect a `forward` call to do (produce the output/prediction).
2. Have a separate function in `Forecaster` for computing the loss, which means we again have to define our own interface for parts of this, rather than relying completely on Lightning's `training_step`.

It's a bit tricky, maybe easiest to start hacking away at it and see what becomes the most convenient.
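To make option 2 a bit more concrete, here is a rough sketch. All class and method names besides `training_step` are hypothetical, the forecast unrolling is just a dummy, and the real `ForecasterModule` would inherit `pytorch_lightning.LightningModule`:

```python
import torch


class Forecaster(torch.nn.Module):
    """Option 2: forward only produces the forecast; the loss lives in a
    separate method, i.e. our own extra interface next to forward."""

    def forward(self, init_states, forcing, num_steps):
        # Dummy unrolling: repeat the last initial state num_steps times
        # (B, num_init, N_grid, d_state) -> (B, num_steps, N_grid, d_state)
        last_state = init_states[:, -1:]
        return last_state.expand(-1, num_steps, -1, -1)

    def loss(self, forecast, target):
        # Per-forecast loss, kept with the Forecaster
        return torch.nn.functional.mse_loss(forecast, target)


class ForecasterModule:  # would inherit pytorch_lightning.LightningModule
    def __init__(self, forecaster):
        self.forecaster = forecaster

    def training_step(self, batch, batch_idx):
        # Unpack the batch, produce a forecast, then ask the Forecaster
        # for its loss. The target tensor never becomes a forward argument
        # further down the hierarchy.
        init_states, target, forcing = batch
        forecast = self.forecaster(
            init_states, forcing, num_steps=target.shape[1]
        )
        return self.forecaster.loss(forecast, target)
```

The cost of this design is visible in `training_step`: it has to know about the extra `loss` method, which is exactly the custom interface next to `forward` discussed above.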
> I was thinking for example of the `BaseGraphModel` family of methods vs a future `CNNPredictor`: these will have different requirements in terms of what a single batch should contain (spatial coordinates stacked or not). Have you thought about at which level this will impact the `forward` methods of this new hierarchy? Maybe this is self-evident, but to me it is quite tricky. It might simply be a question of well-written docstrings and a convention (in terms of which dimensions in the input tensor represent what) that we apply throughout the code.
Yes, this is a good point. I have given this some thought. One should note that having one (flattened) spatial dimension or two (x, y) should not majorly impact anything before the `StepPredictor`s. Since moving between these is just a cheap reshape, it is mostly a matter of convention. You can either flatten already in the dataset class and then reshape to 2 spatial dims in e.g. CNN-models (current setup), or you keep your spatial dims from the dataloader and flatten in the graph-based models. This is something that should not be very hard to change later on if needed.
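As a quick illustration of why the reshape is cheap (the tensor shapes here are arbitrary, not from the actual codebase):

```python
import torch

# A batch with 2 spatial dims (B, N_x, N_y, d_state) vs its flattened
# (B, N_grid, d_state) form
x_2d = torch.randn(2, 4, 5, 3)
x_flat = x_2d.flatten(1, 2)          # (2, 20, 3), a view, no data copied
x_back = x_flat.reshape(2, 4, 5, 3)  # back to (2, 4, 5, 3)

assert torch.equal(x_2d, x_back)
assert x_flat.data_ptr() == x_2d.data_ptr()  # same underlying storage
```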
Here is an example of a dummy CNN-class I have lying around that handles this with reshapes (using the old class hierarchy):
```python
import torch

from neural_lam.models.ar_model import ARModel
from neural_lam import constants


class NewModel(ARModel):
    """
    A new auto-regressive weather forecasting model
    """

    def __init__(self, args):
        super().__init__(args)

        # Some dimensionalities that can be useful to have stored
        self.input_dim = (
            2 * constants.grid_state_dim
            + constants.grid_forcing_dim
            + constants.batch_static_feature_dim
        )
        self.output_dim = constants.grid_state_dim

        # TODO: Define modules as members here that will be used in predict_step
        self.layer = torch.nn.Conv2d(self.input_dim, self.output_dim, 1)  # Dummy layer

    def predict_step(self, prev_state, prev_prev_state, batch_static_features, forcing):
        """
        Predict weather state one time step ahead
        X_{t-1}, X_t -> X_{t+1}

        prev_state: (B, N_grid, d_state), weather state X_t at time t
        prev_prev_state: (B, N_grid, d_state), weather state X_{t-1} at time t-1
        batch_static_features: (B, N_grid, batch_static_feature_dim), static forcing
        forcing: (B, N_grid, forcing_dim), dynamic forcing

        Returns:
        next_state: (B, N_grid, d_state), predicted weather state X_{t+1} at time t+1
        pred_std: None or (B, N_grid, d_state), predicted standard-deviations
            (pred_std can be ignored by just returning None)
        """
        # Reshape 1d grid to 2d image
        input_flat = torch.cat(
            (prev_state, prev_prev_state, batch_static_features, forcing), dim=-1
        )  # (B, N_grid, d_input)
        input_grid = torch.reshape(
            input_flat, (-1, *constants.grid_shape, input_flat.shape[2])
        )  # (B, N_x, N_y, d_input)

        # Most computer vision methods in torch want channel dimension first
        input_grid = input_grid.permute((0, 3, 1, 2)).contiguous()  # (B, d_input, N_x, N_y)

        # TODO: Feed input_grid through some model to predict output_grid
        output_grid = self.layer(input_grid)  # (B, d_state, N_x, N_y)

        # Reshape back from 2d to flattened grid dimension
        output_grid = output_grid.permute((0, 2, 3, 1))  # (B, N_x, N_y, d_state)
        next_state = output_grid.flatten(1, 2)  # (B, N_grid, d_state)

        return next_state, None
```
A good reason to keep it as is, flattening in the Dataset class, is that once we start moving to more refined boundary setups, on different gridding, there will no longer be a 2d grid representation of the input data. Grid cells will be on irregular grids, meaning that it is not trivial to apply CNN models to them. Keeping things flattened all the way to the CNN model means that if you want to define a CNN on this, you have to decide how to handle this irregularity. We don't have to make such decisions anywhere before the `forward` call in the CNN class.
> Minor thought: should it maybe be `BaseGraphModelPredictor` rather than `BaseGraphModel`, or something like that? Just to follow the `*Predictor` naming convention you are setting up here.
Yes, that is a good idea. Or even `BaseGraphPredictor` and `BaseHiGraphPredictor`. Not sure if one would also want to name the final models (e.g. HiLAM) something like `HiLAMPredictor`, just to specify that these are also `StepPredictor`s? On the other hand, it's maybe best to keep to the exact naming of a model used in e.g. a paper.
> Maybe when you/we start coding this we could start by including the additions you foresee making (ensemble prediction classes, CNN, ViT) as simply empty placeholder classes? Then we can be sure the data-structures we pass around will be general enough to apply in the future too?
This could be smart, yes. For the computer vision models (CNN, ViT) I don't think that there will be much to keep in mind (see discussion above), but for the ensemble model this could make a lot of sense. Especially since I will anyhow populate that with the code from the probabilistic model branches later. My plan is to do this refactoring first, to make the merging of that easier and nicer.
Overall it looks like there is support for this idea, so I can start writing the code for it. Then we can discuss more details in the upcoming PR. I am happy to do the work on this, but I have a hard time giving a timeline, as it is not directly crucial to progress in ongoing research projects. Anyhow, I think this could potentially fit in v0.3.0 on the roadmap?
## Background

The different models that can be trained in Neural-LAM are currently all sub-classes of `pytorch_lightning.LightningModule`. In particular, much of the functionality sits in the first subclass, `ARModel`. The current hierarchy looks like this:

(I am making these rather than some more fancy UML-diagrams since I think this should be enough for the level of detail we need to discuss here.)

## The problem

In the current hierarchy everything is a subclass of `ARModel`. This has a number of drawbacks:

- All functionality ends up in one class. There is no clear division of responsibility, but rather we end up with one class that does everything from logging, device-handling and unrolling forecasts to the actual forward and backward pass through GNNs.
- The subclasses do not utilize the standard torch `forward` calls, but rather must resort to our own similar construction (e.g. `predict_step`).
- This limits us to deterministic, auto-regressive models.
- This is hard to extend upon for ensemble-models.

## Proposed new hierarchy

I propose to split up the current class hierarchy into subclasses that have clear responsibilities. These should not just all inherit `ARModel`, but rather be members of each other as suitable. A first idea for this is shown below, including also potential future classes for new models (to show how this is more extensible). The core components are (I here count static features as part of forcing):
- `ForecasterModule`: Takes over much of the responsibility of the old `ARModel`. Handles things not directly related to the neural network components such as plotting, logging, moving batches to the right device. This inherits `pytorch_lightning.LightningModule` and has the different train/val/test steps. In each step (train/val/test), unpacks the batch of tensors and uses a `Forecaster` to produce a full forecast. Also responsible for computing the loss based on a produced forecast (could also be in `Forecaster`, not entirely sure about this).
- `Forecaster`: A generic forecaster capable of mapping from a set of initial states, forcing and boundary forcing into a full forecast of the requested length.
- `ARForecaster`: Subclass of `Forecaster` that uses an auto-regressive strategy to unroll a forecast. Makes use of a `StepPredictor` at each AR step.
- `StepPredictor`: A model mapping from the two previous time steps + forcing + boundary forcing to a prediction of the next state. Corresponds to the $\hat{f}$ function in Oskarsson et al. The current (and future) graph-based models are subclasses of `StepPredictor`.

In the figure above we can also see how new kinds of models could fit into this hierarchy:
This is supposed to be a starting point for discussion and there will likely be things I have not thought about. Some parts of this will have to be hammered out when actually writing these classes, but I'd rather have the discussion about whether this is a good direction to take things before starting to do too much work. Tagging @leifdenby and @sadamov for visibility.