google-deepmind / graphcast

Apache License 2.0

Regarding training #12

Closed: yogeshverma1998 closed this issue 5 months ago

yogeshverma1998 commented 7 months ago

Hi,

I am trying to train GraphCast on a dataset. Since the main training loop is not included in the repo, I am following the e.ipynb file to write one. The demo only computes the loss for a single iteration over a small number of forecast steps.

How do you train it over a large number of steps, e.g. two months of data, where the steps would presumably need to be batched? I cannot find a function that batches long trajectories using 'data_utils.extract_inputs_targets_forcings' so that the gradient can be backpropagated.
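One way to frame the question: a long trajectory of T time steps can be cut into many short training windows, each with two input frames followed by k target frames. The sketch below only computes the (input indices, target indices) pairs in plain Python. The helper name `make_windows` is hypothetical and not part of the graphcast repo; each resulting window would still need to be turned into inputs, targets, and forcings (e.g. via 'data_utils.extract_inputs_targets_forcings') before computing a loss.

```python
from typing import List, Tuple


def make_windows(num_times: int, num_inputs: int = 2, num_targets: int = 1,
                 stride: int = 1) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical helper: split a trajectory of `num_times` steps into
    overlapping (input indices, target indices) windows."""
    window_len = num_inputs + num_targets
    if num_times < window_len:
        raise ValueError("trajectory too short for a single window")
    windows = []
    # Slide a window of length num_inputs + num_targets along the trajectory.
    for start in range(0, num_times - window_len + 1, stride):
        inputs = list(range(start, start + num_inputs))
        targets = list(range(start + num_inputs, start + window_len))
        windows.append((inputs, targets))
    return windows


# Two months of 6-hourly data is roughly 60 days * 4 = 240 steps.
windows = make_windows(240)
print(len(windows), windows[0], windows[-1])
```

With the default single-step setup, 240 time steps yield 238 short windows, each of which can be trained on independently instead of backpropagating through the whole trajectory.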

Regards, Yogesh

tewalds commented 7 months ago

We primarily train on ERA5 data from 1979 to 2015 with a batch size of 32 (batch size 1 per chip), sampling batches uniformly from the training range. Most of our training is a single step (i.e. two inputs and one output), though we fine-tune for a few autoregressive steps. For low-resolution inputs and large chips, you can use autoregressive.py to backprop through multiple steps (we've done 12 steps, i.e. 3 days). We evaluate models up to 10 days (i.e. 40 steps) but have never trained them with that many steps. Backpropping through two-month trajectories would require much more compute, memory, and engineering than is currently available, or very low-resolution inputs.
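The recipe described above (batches of 32 windows drawn uniformly from the training period, two input frames plus a short run of targets) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the GraphCast training code: `sample_batch_starts` and `steps_to_days` are hypothetical names, and the 6-hour step is inferred from "12 steps, i.e. 3 days".

```python
import random
from typing import List, Optional

STEP_HOURS = 6  # assumption: 12 steps = 3 days implies 6-hour steps


def steps_to_days(num_steps: int, step_hours: int = STEP_HOURS) -> float:
    """Convert a number of model steps into a forecast horizon in days."""
    return num_steps * step_hours / 24


def sample_batch_starts(num_times: int, batch_size: int = 32,
                        num_inputs: int = 2, num_targets: int = 1,
                        rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical helper: draw `batch_size` window start indices uniformly
    (with replacement) so that each window of inputs + targets fits."""
    rng = rng or random.Random()
    last_start = num_times - (num_inputs + num_targets)
    if last_start < 0:
        raise ValueError("not enough time steps for one window")
    # randint is inclusive on both ends, so every valid start is possible.
    return [rng.randint(0, last_start) for _ in range(batch_size)]


print(steps_to_days(12), steps_to_days(40), steps_to_days(240))
starts = sample_batch_starts(num_times=1000, num_targets=12,
                             rng=random.Random(0))
print(len(starts))
```

At this step size, two months is about 240 steps, 20 times the 12-step fine-tuning horizon. With plain backpropagation, the activations that must be kept for the backward pass grow roughly linearly with the number of unrolled steps, which is one way to see why training on such trajectories is so costly.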

mjwillson commented 5 months ago

Closing, as the question appears to be answered. Please also refer to our paper for details about training.