mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling

Handling checkpoint-breaking changes #48

Open joeloskarsson opened 3 weeks ago

joeloskarsson commented 3 weeks ago

Background

As we make more changes to the code there will be points where checkpoints from saved models cannot be directly loaded in a newer version of neural-lam. This happens in particular if we change the variable names of nn.Module attributes or the overall structure of the model classes. It would be good to have a policy for how we want to handle such breaking changes. This issue is for discussing that.

Proposals

I see three main options:

  1. Ignore this issue, and only guarantee that checkpoints trained with a specific version of neural-lam work with that version. If you upgrade, you have to re-train models or do some "surgery" on your checkpoint files yourself.
  2. Make sure that we can load checkpoints from all previous versions. This is doable as long as the same neural-network parameters are in there, just under different names. We already have an example of this in the current ARModel: https://github.com/mllam/neural-lam/blob/9d558d1f0d343cfe6e0babaa8d9e6c45b852fe21/neural_lam/models/ar_model.py#L576-L596
  3. Create a separate script for converting checkpoint files from one version to another. The required logic is the same as in point 2, but moved out into a separate script that takes a checkpoint file as input and saves a new checkpoint file compatible with the new neural-lam version (a minimal sketch of such a script is shown after this list).
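
As a rough illustration of option 3, here is a minimal sketch of what such a conversion script could look like. The script name, rename table, and file handling are assumptions for illustration, not existing neural-lam code:

```python
# convert_checkpoint.py -- hypothetical sketch of option 3
import argparse

import torch

# Hypothetical table mapping old state-dict key prefixes to their new names;
# in practice one such table would be maintained per breaking release
KEY_RENAMES = {
    "grid_mlp": "g2m_gnn.grid_mlp",
}


def convert(in_path, out_path):
    # Load the full Lightning checkpoint (weights, hyperparameters, optimizer state)
    ckpt = torch.load(in_path, map_location="cpu")
    state_dict = ckpt["state_dict"]
    for key in list(state_dict.keys()):
        for old, new in KEY_RENAMES.items():
            if key.startswith(old):
                # Move the tensor to its new key, leaving the values untouched
                state_dict[new + key[len(old):]] = state_dict.pop(key)
                break
    torch.save(ckpt, out_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert an old neural-lam checkpoint")
    parser.add_argument("in_path", help="Checkpoint saved by an older version")
    parser.add_argument("out_path", help="Where to save the converted checkpoint")
    convert(**vars(parser.parse_args()))
```

Usage would then be something like python convert_checkpoint.py old.ckpt new.ckpt, producing a file loadable by the newer version.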

Considerations for points 2 and 3

My view

Tagging @leifdenby and @sadamov to get your input.

sadamov commented 3 weeks ago

These are some very important considerations. I myself have angered some colleagues by making old checkpoints unusable. Now I am also looking at #49, which would give the user much more flexibility with respect to model choices. Mostly for that reason, and because I don't think we have the person-power to guarantee backwards compatibility, I am leaning towards option 1. Maybe in the future, with a more stable repo and more staff, we can implement 3? What I would do for now is set up very solid logging with wandb:

With such information, every checkpoint should remain usable for a long time. Maybe I am greatly overestimating how much time 3 would require; if that is the case, I will gladly change my opinion.
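
For concreteness, here is a minimal sketch of the kind of run metadata such logging could record. The config keys and the get_git_commit helper are illustrative assumptions, not existing neural-lam code:

```python
import subprocess

import wandb


def get_git_commit():
    # Record the exact code revision the checkpoint was trained with
    return subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()


run = wandb.init(
    project="neural-lam",
    config={
        # Enough metadata to re-instantiate the model exactly as it was at
        # training time, even if plain checkpoint loading breaks later
        "neural_lam_commit": get_git_commit(),
        "model_class": "GraphLAM",  # hypothetical hyperparameters below
        "hidden_dim": 64,
        "processor_layers": 4,
    },
)
```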

joeloskarsson commented 3 weeks ago

I am a bit unsure myself about how much work it would really be. As long as we only rename members or change the hierarchy of nn.Modules, it just boils down to renaming keys in the state dict. We already have an implementation of this here: https://github.com/mllam/neural-lam/blob/9d558d1f0d343cfe6e0babaa8d9e6c45b852fe21/neural_lam/models/ar_model.py#L584-L596 It just has to be generalized to more keys than g2m_gnn.grid_mlp.0.weight.
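
One way to generalize it would be a per-version table of renamed prefixes applied in Lightning's on_load_checkpoint hook. A minimal sketch, where the table entry is hypothetical:

```python
import pytorch_lightning as pl


class ARModel(pl.LightningModule):
    # Hypothetical table of state-dict key prefixes renamed between versions
    OLD_TO_NEW_PREFIX = {
        "grid_mlp": "g2m_gnn.grid_mlp",
    }

    def on_load_checkpoint(self, checkpoint):
        # Rewrite outdated keys in-place before Lightning loads the weights
        state_dict = checkpoint["state_dict"]
        for key in list(state_dict.keys()):
            for old, new in self.OLD_TO_NEW_PREFIX.items():
                if key.startswith(old):
                    state_dict[new + key[len(old):]] = state_dict.pop(key)
                    break
```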

Where things can get tricky is if we reorder input features or change the dimensionality of something; the sketch below shows what fixing the reordering case would involve.
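
For feature reordering, the first layer that touches the input would need its weight columns permuted to match the new order. A minimal sketch, where the checkpoint paths and the permutation are made up and only the key name comes from the example above:

```python
import torch

ckpt = torch.load("old_order.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# perm[i] = index that the feature now at position i had in the old ordering
perm = torch.tensor([2, 0, 1, 3])

# A linear layer's weight has shape (out_features, in_features), so reordering
# input features means permuting its columns
w = state_dict["g2m_gnn.grid_mlp.0.weight"]
state_dict["g2m_gnn.grid_mlp.0.weight"] = w[:, perm]

torch.save(ckpt, "new_order.ckpt")
```

But thinking about this a bit more now I realize: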