Closed ohallstrom closed 1 month ago
I think there are three levels of increasing complexity for implementing batch size warmup, ranging from hooking the `before_train_batch` event (this could be modeled on the Sequence Length Warmup algorithm, and shouldn't cause any scheduling issues as Composer defines schedulers on the batch or token level) up to a custom `Trainer`.

Each setup has some downsides: 1) discards data, 2) is a bit hacky, 3) needs to work with restartable training and other Composer features, and 4) potentially locks out future versions of Composer and might not play well with existing callbacks and algorithms.

In addition to reducing complexity, options 1-3 would allow for more accurate loss logging. I'd personally lean towards trying out option 1 for simplicity, although option 3 should also work well without data loss.
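To make option 1 concrete, here is a rough stdlib-only sketch of the sample-dropping idea. The names (`current_batch_size`, the linear warmup schedule) are hypothetical, and the actual Composer `Callback` wiring is omitted; it only illustrates trimming each batch to the current target size on the `before_train_batch` event:

```python
# Hypothetical sketch of option 1: trim batches during warmup.
# `current_batch_size` and the linear schedule are made up for
# illustration, not a Composer API.

def current_batch_size(step: int, warmup_steps: int,
                       start_bs: int, final_bs: int) -> int:
    """Linearly interpolate the target batch size during warmup."""
    if step >= warmup_steps:
        return final_bs
    return start_bs + (final_bs - start_bs) * step // warmup_steps

def before_train_batch(batch: list, step: int, warmup_steps: int,
                       start_bs: int, final_bs: int) -> list:
    """Drop samples so the batch matches the current target size.

    In Composer this logic would live in a Callback reacting to the
    before_train_batch event; here it is a plain function. The
    dropped samples are simply discarded (downside 1 above).
    """
    target = current_batch_size(step, warmup_steps, start_bs, final_bs)
    return batch[:target]
```

For example, with `start_bs=4`, `final_bs=8`, and `warmup_steps=4`, a batch of 8 samples at step 0 is trimmed to its first 4 samples, and from step 4 onwards the full batch passes through untouched.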
Some very quick ablations were run on this branch to investigate the impact of batch rampup, and given the results and a discussion among the team, we have decided not to go forward with batch rampup for the large training runs.
This branch has served its purpose for now, so I will close this PR.
Here is a draft implementation of batch rampup; let me know what you think!
Changes
This PR creates a new custom trainer, based on the standard `Trainer` in Composer, that allows batch rampup.
To ensure even hardware utilization, batch rampup in this PR works simply by removing gradient accumulation steps. Say we have `n_d` devices, micro batch size `mbs`, global batch size `gbs`, and `n_grad_acc` gradient accumulation steps. Then we can do at most `log2(n_grad_acc)` rampup stages. For example, if the global batch size is 4096 and we have 4 gradient accumulation steps, we can do two rampup stages: starting rampup from gbs 1024, then going to a second rampup stage with gbs 2048, before reaching the final gbs 4096. This should work with our current single node setup as it uses gradient accumulation. However, this rampup implementation is less suitable for training with many nodes and consequently many devices, as we won't be able to do as much (if any) rampup.
Logging always happens every `gbs` samples, regardless of whether rampup is used. Consequently, during rampup the loss is not logged for every step; instead, the mean loss over all steps in the last `gbs` samples is logged. This means that if the final global batch size is 4096 and we use gbs 2048 during rampup, we log the mean of the two latest losses on 2048 samples every 4096th sample.

To use batch rampup, specify the rampup length with `batch_rampup` in `composer.core.time` format. Optionally, the batch size to initialize the rampup from can be specified with `inital_global_train_batch_size`; otherwise the rampup will start from the lowest possible batch size.

Let me know if anything is unclear, or if you have any comments or suggestions :)
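As a sanity check on the logging scheme described above, here is a small stdlib sketch of the averaging (the helper name is hypothetical, not the PR's code): consecutive step losses are grouped until a full `gbs` worth of samples has been seen, then their mean is logged.

```python
def logged_losses(step_losses: list, step_gbs: int,
                  final_gbs: int) -> list:
    """Average groups of step losses so that one value is logged
    per final_gbs samples, regardless of the current step size."""
    k = final_gbs // step_gbs  # number of steps per logging point
    return [sum(step_losses[i:i + k]) / k
            for i in range(0, len(step_losses) - k + 1, k)]
```

With `step_gbs=2048` and `final_gbs=4096`, each logged value is the mean of two consecutive step losses, e.g. `logged_losses([1.0, 3.0, 2.0, 4.0], 2048, 4096)` gives `[2.0, 3.0]`; outside rampup (`step_gbs == final_gbs`) every step loss is logged unchanged.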
Discussions
Adding batch rampup as requested in issue #19
Tests
I have verified that the new custom trainer has identical throughput and loss compared to the standard trainer when batch rampup is not used. I have also verified that during the first stage of rampup from gbs 4096 to 8192, the custom trainer's loss is very similar to no rampup with gbs 4096 (taking the different logging frequencies into account). However, I have not yet been able to fully verify the custom trainer.
I have also run all tests, and none failed.