google-deepmind / graphcast


Training: Learning Rate schedule based on iterations rather than epochs #80

Closed · gacuervol closed this issue 1 month ago

gacuervol commented 1 month ago

Description:

We're replicating the training process for the model described in Lam et al. (2023). In section 4.5 ("Curriculum training schedule"), the authors describe three training phases. During phase 2, the learning rate is adjusted as a function of the number of iterations (the number of times a batch is processed), using a cosine decay function.

In the paper, the learning rate is adjusted based on iterations rather than the number of epochs (complete passes through the entire dataset). We're curious why the authors chose this approach.

Here's why this matters: if we follow the paper and schedule by iterations, the learning rate can change in the middle of an epoch, i.e. before every example has been seen once. This is especially relevant for our smaller dataset, where each epoch contains fewer iterations.

Our question: Is there a benefit to adjusting the learning rate based on iterations instead of epochs? We'd appreciate any insights or explanations you can provide.
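For context, here is a minimal sketch of what an iteration-based cosine schedule looks like in practice, using optax (the JAX optimizer library used alongside this codebase). The step counts and peak rate below are illustrative values, not the exact configuration from the paper:

```python
# Sketch (not the authors' code): an iteration-based warmup + cosine decay
# schedule built with optax. All numbers are illustrative.
import optax

total_steps = 300_000   # total training iterations (not epochs)
warmup_steps = 1_000    # linear warm-up phase
peak_lr = 1e-3

# The schedule is a function of the *iteration count*: it is queried once per
# optimizer update, regardless of where that update falls within an "epoch".
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=total_steps,  # cosine decay covers the steps after warm-up
    end_value=0.0,
)

optimizer = optax.adamw(learning_rate=schedule)
```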

Labels:

training, learning rate, training schedule

alvarosg commented 1 month ago

Thanks for your message. It is not just that we adjusted the learning rate as a function of iterations rather than epochs; we also did not train with proper epochs at all. At each iteration we simply sampled 32 examples from the longer ground-truth trajectory, randomly and without replacement (within each batch). So we only have "epochs" on average: nothing prevented the same sequence from being sampled twice before some other sequence was sampled once.
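To make this sampling scheme concrete, here is a minimal sketch of per-iteration sampling without epoch boundaries; the dataset size and loop structure are illustrative, not the actual GraphCast data pipeline:

```python
# Sketch of per-iteration sampling with no epoch boundaries (illustrative).
import numpy as np

num_examples = 54_000    # ~one "epoch" worth of training windows
batch_size = 32
num_iterations = 300_000

rng = np.random.default_rng(0)
for step in range(num_iterations):
    # Sample without replacement *within* the batch, but independently across
    # iterations, so a given example may recur before others are seen at all.
    idx = rng.choice(num_examples, size=batch_size, replace=False)
    # batch = dataset[idx]; loss computation and optimizer update go here.
```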

We believed this did not matter in our case: an epoch is about 54k training examples, and we trained for 300k steps with a batch size of 32, which corresponds to roughly 180 epochs. So it probably does not matter if some examples are sampled a bit more than 180 times and others a bit less.

With a small dataset it probably still does not matter much, as long as the number of training iterations is large compared to the dataset size and the learning rate decay is very slow (as in our case). But if you are getting to the limit where each example is only seen a few times and the learning rate decays quickly, then I think it makes sense to do it the way you are proposing.
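For that small-dataset case, here is a minimal sketch of the epoch-based alternative discussed above, assuming a plain cosine decay stepped once per epoch (values are illustrative, not from the paper):

```python
# Sketch of an epoch-based schedule: the learning rate is recomputed only at
# epoch boundaries, so it stays fixed until every example has been seen once.
import math

def cosine_lr_by_epoch(epoch, total_epochs, peak_lr=1e-3, end_lr=0.0):
    """Cosine decay evaluated per epoch instead of per iteration."""
    progress = min(epoch / total_epochs, 1.0)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1 + math.cos(math.pi * progress))

for epoch in range(50):
    lr = cosine_lr_by_epoch(epoch, total_epochs=50)
    # ... iterate over all batches of the shuffled dataset using this lr ...
```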

gacuervol commented 1 month ago

Thank you for the clarification and for detailing the training process; it is very helpful for understanding the methodology behind the model training. Your point about smaller datasets and faster learning rate decay suggests that an epoch-based approach might be more suitable in those scenarios.

It seems that the issue has been satisfactorily resolved.