google-research / tuning_playbook

A playbook for systematically maximizing the performance of deep learning models.

Why avoid gradient accumulation? #69

Open RonanKMcGovern opened 4 months ago

RonanKMcGovern commented 4 months ago

There is this quote:

> **Gradient accumulation** simulates a larger batch size than the hardware can support and therefore does not provide any throughput benefits. It should generally be avoided in applied work.

For large GPUs and multi-GPU setups, I can see this making sense, since you can run batches of 32 and don't need accumulation.

But on smaller GPUs, gradient accumulation can be important because it averages gradients over the virtual batches, which stabilises training.

Am I mistaken or missing something?
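For concreteness, this is the pattern I have in mind - a minimal sketch assuming PyTorch (the playbook doesn't prescribe a framework, and the model, optimizer, and sizes here are just placeholders):

```python
import torch
from torch import nn

# Toy model and data, purely for illustration.
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
micro_batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(8)]

accumulation_steps = 4  # effective batch size = 8 * 4 = 32

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(micro_batches):
    loss = loss_fn(model(inputs), targets)
    # Scale the loss so the accumulated gradient equals the mean over the effective batch.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one parameter update per effective batch
        optimizer.zero_grad()
```

There is still only one optimizer update every four micro-batches, so the update maths matches a batch of 32, but each forward/backward pass still runs at micro-batch size, which I take to be the playbook's point about throughput.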

DimitrisMantas commented 3 months ago

A lot of architectures have BN layers which don't work properly unless they're actually backpropagated through, I think.

RonanKMcGovern commented 3 months ago

Interesting. What is BN? Bias?

When you say "a lot of", does that include Llama 2 and 3 type models?

Basically you're saying that accumulating the gradients isn't enough - some important info is thrown away once you move to the next forward pass?

DimitrisMantas commented 3 months ago

Batch normalization. Essentially, BN blocks keep track of the running batch mean and standard deviation and use them to normalize their inputs.

These running statistics are non-trainable and are updated on every micro-batch the blocks receive. However, with gradient accumulation the number of forward passes per epoch no longer matches the number of optimizer updates, so BN blocks end up computing "incorrect" statistics: they track the small micro-batches rather than the effective batch. The problem is magnified by the blocks' trainable parameters (the scale and shift) still being updated according to the accumulated batches. Basically, the batches and their descriptive statistics become "unsynchronized".
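A tiny sketch of the mismatch I mean, assuming PyTorch's `nn.BatchNorm1d` (the sizes and names are just placeholders):

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)                      # running_mean starts at 0, running_var at 1

effective_batch = torch.randn(32, 4) * 3 + 5
micro_batches = effective_batch.chunk(4)    # four micro-batches of size 8

# With gradient accumulation, BN gets four separate forward passes,
# and its running statistics are updated after each one.
for mb in micro_batches:
    bn(mb)
accumulated_stats = bn.running_mean.clone()

# Without accumulation, BN would see the effective batch in a single forward pass.
bn.reset_running_stats()
bn(effective_batch)
full_batch_stats = bn.running_mean.clone()

print(accumulated_stats)   # momentum applied four times to noisy micro-batch means
print(full_batch_stats)    # momentum applied once to the full-batch mean
```

During training each micro-batch is also normalized with its own statistics rather than the effective batch's, so the activations differ as well, not just the running averages.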

BN blocks are very popular in computer vision tasks, and unfortunately, I'm not too familiar with much else. However, I believe transformer blocks typically use layer normalization, which does not depend on batch size, so you should be safe.
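For contrast, layer normalization is computed per sample over the feature dimension, so the batch composition (and hence accumulation) doesn't enter into it - a quick check, again assuming PyTorch:

```python
import torch
from torch import nn

torch.manual_seed(0)
ln = nn.LayerNorm(4)
x = torch.randn(32, 4)

# Each sample is normalized over its own features, so the output for a sample
# is the same whether it is processed alone or as part of a batch.
batched = ln(x)
one_by_one = torch.cat([ln(row.unsqueeze(0)) for row in x])
print(torch.allclose(batched, one_by_one))  # True
```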

DimitrisMantas commented 3 months ago

By the way, large batch sizes are just as "dangerous" as small ones due to potential oversmoothing of the gradient landscape. It's kind of a "pick your poison" situation.

RonanKMcGovern commented 3 months ago

Thanks, yeah, agreed on the problems at big batches.

And yeah, that makes sense re gradient accumulation and batch norm. Llama 2 and 3 use layer norm so should be fine, but good to know for multi-modal models - I need to check whether CLIP has batch norm.

— Reply to this email directly, view it on GitHub https://github.com/google-research/tuning_playbook/issues/69#issuecomment-2161708721, or unsubscribe https://github.com/notifications/unsubscribe-auth/ASVG6CUBGN5YHYUUFXQHODDZG53CZAVCNFSM6AAAAABIRBGLXCVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCNRRG4YDQNZSGE . You are receiving this because you authored the thread.Message ID: @.***>