openclimatefix / metnet

PyTorch Implementation of Google Research's MetNet and MetNet-2
MIT License
237 stars, 49 forks

Pytorch-Lightning #21

Open ValterFallenius opened 2 years ago

ValterFallenius commented 2 years ago

Pull Request

Description

Converted to pytorch-lightning for easy parallelization. The forward pass still needs work: how do we make it depend on a random lead time while also coupling it to the loss function, so that the prediction is paired with the correct slice of the Y-tensor? Also, the paper mentions that they use random lead times during training, but for validation they check every possible lead time once.

Right now the code creates forecast_steps = 6 copies of the input tensor and concatenates each copy with a different one-hot encoding. Copying the input tensor once per lead time will be unacceptable for longer forecast horizons. I want to go back to using self.ct(x, random_int) in the forward pass to reduce memory usage, but I haven't figured out how to do this while also pairing it with the correct ground truth.
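To put rough numbers on the memory concern above, here is a small back-of-the-envelope sketch. The input shape is purely hypothetical (it is not taken from this PR), float32 storage is assumed, and the extra one-hot channels are ignored, so treat the result as an order-of-magnitude estimate only.

```python
from math import prod


def input_bytes(shape, bytes_per_element=4):
    """Bytes needed to store one tensor of the given shape (float32 by default)."""
    return prod(shape) * bytes_per_element


# Hypothetical input shape (batch, time, channels, height, width); an assumption.
shape = (8, 7, 16, 256, 256)
single = input_bytes(shape)

for forecast_steps in (6, 60, 240):
    copied = forecast_steps * single  # one full copy of the input per lead time
    print(
        f"{forecast_steps:>3} lead times: "
        f"{copied / 2**30:.2f} GiB copied vs {single / 2**30:.2f} GiB for a single input"
    )
```

The cost grows linearly with the number of lead times, which is why conditioning on one lead time per forward pass is attractive.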

Dataset

There's a subsample of my dataset linked in the README. These are not the "prettified" fields that I will use for the thesis, but they have the same dimensions and value range, so they are good enough for testing. The dataset consists of dBZ fields along with longitude, latitude and elevation encodings (southern half of Sweden).

jacobbieker commented 2 years ago

First of all, thanks for submitting this! And for releasing some of the data! Would it be okay if I rehosted the sample you released on HuggingFace Datasets? Even though it is not the 'prettified' data, and only a subsample, it could still be quite nice to have training data from somewhere other than the US. But if not, that's fine too!

For training, one idea would be to randomly pick the lead time in the dataloader and load only that future time slice as the target, returning the history frames, the single future lead time frame, and the lead time index to the model, if that makes sense? That should reduce the memory requirements quite a bit.

https://github.com/openclimatefix/metnet/pull/20 changes MetNet and MetNet-2 to use only a single lead time by default, which should help with memory consumption. You can look at test_model.py to see how to predict for multiple lead times.
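A minimal sketch of that dataloader idea, assuming the frames are already available as one tensor of shape (time, channels, height, width). The class name `RandomLeadTimeDataset` and its arguments are made up for illustration and are not part of this repo.

```python
import torch
from torch.utils.data import Dataset


class RandomLeadTimeDataset(Dataset):
    """Hypothetical dataset: returns history, one future frame, and its lead time index."""

    def __init__(self, frames, history_steps, forecast_steps):
        # frames: tensor of shape (time, channels, height, width)
        self.frames = frames
        self.history_steps = history_steps
        self.forecast_steps = forecast_steps

    def __len__(self):
        # Each sample needs history_steps past frames plus forecast_steps future frames.
        return self.frames.shape[0] - self.history_steps - self.forecast_steps + 1

    def __getitem__(self, idx):
        history = self.frames[idx : idx + self.history_steps]
        # Draw one lead time index uniformly from [0, forecast_steps).
        lead_time = int(torch.randint(0, self.forecast_steps, (1,)))
        # Only the frame at the chosen lead time is returned as the target.
        target = self.frames[idx + self.history_steps + lead_time]
        return history, target, lead_time
```

Wrapped in a standard `DataLoader`, each batch would then carry one target frame per sample instead of all `forecast_steps` frames. Note that samples in the same batch can end up with different lead times, so the model would need to accept a lead time per sample.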

For the overall PR, I still need to go through it a bit more, but ideally we want to keep this repo quite general, with minimal dependencies for importing the models. So I would recommend removing the scripts and such that are specific to running on your machine or your cluster, and moving the rest under a train/ folder instead. That way, this adds an example of how to wrap the model in PyTorch Lightning, load some data, and train it from scratch on data that we can then point to in the documentation. We can then have an optional dependency on PyTorch Lightning for people who want to use it, while needing only plain PyTorch for the basic models.

ValterFallenius commented 2 years ago

Let me get back to you on the dataset. My supervisor still hasn't answered me... In the meantime you can use it yourself.

ValterFallenius commented 2 years ago

Damn, I just solved the conditional lead time problem. It was such an easy fix too... I feel stupid.

I didn't change the data loader, instead I added a random integer into the training loop and fed it as input to the forward pass, then sliced the Y-tensor accordingly.
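Roughly, the approach described above might look like the sketch below. This is not the code from the PR: `ToyLeadTimeModel` is a tiny stand-in rather than the MetNet architecture, the function names are hypothetical, and the MSE loss and the Y-tensor layout (batch, forecast_steps, 1, height, width) are assumptions made only to keep the example self-contained.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyLeadTimeModel(nn.Module):
    """Stand-in for a lead-time-conditioned model; NOT the MetNet architecture."""

    def __init__(self, channels, forecast_steps):
        super().__init__()
        self.embed = nn.Embedding(forecast_steps, channels)
        self.head = nn.Conv2d(channels, 1, kernel_size=1)

    def forward(self, x, lead_time):
        # x: (batch, channels, height, width); lead_time: Python int
        idx = torch.tensor([lead_time], device=x.device)
        bias = self.embed(idx).view(1, -1, 1, 1)  # broadcast over batch and space
        return self.head(x + bias)


def training_step(model, x, y, forecast_steps):
    # y: (batch, forecast_steps, 1, height, width), all lead times stacked on dim 1
    lead_time = int(torch.randint(0, forecast_steps, (1,)))
    y_hat = model(x, lead_time)
    # Compare the prediction with the matching slice of the Y-tensor only.
    return F.mse_loss(y_hat, y[:, lead_time])


@torch.no_grad()
def validation_step(model, x, y, forecast_steps):
    # Validation visits every lead time once instead of sampling one.
    losses = [F.mse_loss(model(x, t), y[:, t]) for t in range(forecast_steps)]
    return torch.stack(losses).mean()
```

Here a single lead time is shared by the whole batch, so the input is never duplicated; the full Y-tensor is still loaded, which is the trade-off compared with picking the lead time in the data loader.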

Berzelius has been under maintenance yesterday and today. I hope to run the network next week with some TPUs 😄