AnswerDotAI / bert24

Apache License 2.0

Add custom trainer with support for batch rampup #76

Closed ohallstrom closed 1 month ago

ohallstrom commented 3 months ago

Here is a draft implementation of batch rampup, let me know what you think!

Changes

This PR adds a new custom trainer, based on the standard Composer Trainer, that supports batch rampup.

To keep hardware utilization even, this PR implements batch rampup simply by removing gradient accumulation steps. Say we have n_d devices, micro batch size mbs, global batch size gbs, and n_grad_acc gradient accumulation steps:

mbs * n_d * n_grad_acc = gbs

We can then do at most log2(n_grad_acc) rampup stages. For example, if the global batch size is 4096 and we have 4 gradient accumulation steps, we can do two rampup stages: starting the rampup from a gbs of 1024, moving to a second rampup stage with a gbs of 2048, and then reaching the final gbs of 4096. This works with our current single-node setup, since it uses gradient accumulation. However, this rampup implementation is less suitable for training with many nodes (and consequently many devices), as we won't be able to do as much (if any) rampup.
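As a rough illustration (not code from this PR, and with made-up device counts), the number of possible rampup stages and their batch sizes could be computed like this:

```python
# Illustrative sketch (not the PR's code) of how many rampup stages are
# possible when rampup works by dropping gradient accumulation steps.
import math

def possible_rampup_stages(gbs: int, mbs: int, n_devices: int) -> list[int]:
    """Return the intermediate global batch sizes below the final gbs."""
    n_grad_acc = gbs // (mbs * n_devices)
    n_stages = int(math.log2(n_grad_acc))  # at most log2(n_grad_acc) stages
    return [gbs // 2 ** (n_stages - i) for i in range(n_stages)]

# e.g. gbs=4096 with mbs=32 on 32 devices -> 4 gradient accumulation steps
print(possible_rampup_stages(4096, 32, 32))  # [1024, 2048]
```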

Logging always happens every gbs samples, regardless of whether rampup is used. Consequently, during rampup the loss is not logged for every step; instead, the mean loss over all steps within the last gbs samples is logged. So if the final global batch size is 4096 and we use a gbs of 2048 during rampup, we log the mean of the two latest losses on 2048 samples every 4096th sample.
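As a toy numerical illustration of this logging behaviour (loss values made up):

```python
# With a final gbs of 4096 and a rampup gbs of 2048, two rampup steps fall
# inside one logging window, so their mean is what gets logged.
step_losses = [2.31, 2.27]                         # losses of two 2048-sample steps
logged_loss = sum(step_losses) / len(step_losses)  # value logged at sample 4096
print(logged_loss)                                 # 2.29
```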

To use batch rampup, specify the rampup length using batch_rampup in composer.core.time format. Optionally, the batch size to initialize the rampup from can be specified with inital_global_train_batch_size. Otherwise, the rampup starts from the lowest possible batch size.
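As a hypothetical sketch of the resulting config values (only the key names come from this PR's description; the surrounding config structure and the chosen values are assumptions):

```python
# Hypothetical config snippet; key names from this PR's description.
rampup_cfg = {
    "global_train_batch_size": 4096,
    "batch_rampup": "2000ba",                # rampup length in composer.core.time format
    "inital_global_train_batch_size": 1024,  # optional; else the lowest possible gbs is used
}
```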

Let me know if anything is unclear, if you have any comments or suggestions :)

Discussions

Adding batch rampup as requested in issue #19

Tests

I have verified that the new custom trainer has identical throughput and loss compared to the standard trainer when batch rampup is not used. I have also verified that, during the first stage of a rampup from gbs 4096 to 8192, the custom trainer has very similar loss to no rampup with gbs 4096 (taking the different logging frequencies into account). However, I have not yet been able to verify the custom trainer beyond this.

I have also run all tests, and no test failed.

warner-benjamin commented 3 months ago

I think there are a few options, in order of increasing complexity, for implementing batch size warmup:

  1. Create a Composer Algorithm which implements batch size warmup by removing samples from the per-device batch during the before_train_batch event (sketched at the end of this comment). This could be modeled on the Sequence Length Warmup algorithm. This shouldn't cause any scheduling issues, as Composer defines schedulers at the batch or token level.
  2. Create a DataLoader wrapper which iterates over multiple DataLoaders with different batch sizes on the first epoch.
  3. Create a DataLoader with a batch size schedule. Since we are scheduling in terms of tokens seen, I think Composer wouldn't complain too much.
  4. Modify and subclass the Trainer.

Each setup has some downsides: 1) discards data, 2) is a bit hacky, 3) needs to work with restartable training and other Composer features, and 4) potentially locks out future versions of Composer and might not play well with existing callbacks and algorithms.

In addition to reducing complexity, options 1-3 would all allow for more accurate loss logging.

I'd personally lean towards trying out option 1 for simplicity, although 3 should also work well without discarding data.
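For concreteness, here is a rough, untested sketch of what option 1 could look like, assuming Composer's standard Algorithm interface and a dict-style batch; the class name, arguments, and the linear schedule are made up for illustration:

```python
# Rough sketch of a batch size warmup Algorithm (untested, names are made up).
from composer.core import Algorithm, Event, State
from composer.loggers import Logger


class BatchSizeWarmup(Algorithm):
    """Drops samples from each per-device batch early in training."""

    def __init__(self, start_fraction: float = 0.25, warmup_batches: int = 1000):
        self.start_fraction = start_fraction  # fraction of the batch kept at step 0
        self.warmup_batches = warmup_batches  # batches until the full batch size is used

    def match(self, event: Event, state: State) -> bool:
        return event == Event.BEFORE_TRAIN_BATCH

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        # Linearly ramp the kept fraction from start_fraction up to 1.0.
        progress = min(state.timestamp.batch.value / self.warmup_batches, 1.0)
        keep_frac = self.start_fraction + (1.0 - self.start_fraction) * progress
        if keep_frac >= 1.0:
            return
        # Assumes the batch is a mapping of tensors with the sample dimension first.
        full_size = next(iter(state.batch.values())).shape[0]
        keep = max(1, int(full_size * keep_frac))
        state.batch = {k: v[:keep] for k, v in state.batch.items()}
```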

ohallstrom commented 1 month ago

We ran some quick ablations on this branch to investigate the impact of batch rampup, and given the results and a discussion within the team, we have decided not to go forward with batch rampup for the large training runs.

This branch has served its purpose for now, and I will therefore close this PR.