HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Staged batchsize training #80

Open ClashLuke opened 1 year ago

ClashLuke commented 1 year ago

Some papers, such as "Don't Decay the Learning Rate, Increase the Batch Size", have shown that training with progressively larger batch sizes instead of progressively lower learning rates helps models find a better local minimum by improving stability in the final stages of training. Additionally, it increases training speed, as throughput (in tokens/s) improves with larger batch sizes.

Intuitively, this lets the model take many small updates early in training, when the per-sample gradients in a batch still point in a similar direction. In later stages, the gradients start to point in different directions, so larger batches (or lower learning rates) are required to keep the updates stable.
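A minimal sketch of what a staged schedule could look like (the function name, milestones, and base batch size are hypothetical, not taken from the Olmax codebase): the learning rate stays fixed and the batch size doubles at predefined step milestones, which shrinks the gradient noise scale late in training much like a learning-rate decay would.

```python
def staged_batch_size(step: int, base_batch: int = 256,
                      milestones: tuple = (10_000, 50_000, 200_000)) -> int:
    """Return the batch size to use at a given training step.

    The batch size doubles every time `step` passes a milestone, so the
    effective noise scale of the update shrinks over the course of training
    without touching the learning rate.
    """
    doublings = sum(step >= m for m in milestones)
    return base_batch * (2 ** doublings)


# Usage: the training loop would rebuild (or re-slice) its data pipeline
# whenever the returned batch size changes.
assert staged_batch_size(0) == 256
assert staged_batch_size(60_000) == 1024
```

On TPUs, each increase would presumably also need the data pipeline and sharding to be re-initialised, since the per-device batch dimension changes.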