google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Question: Gradient Accumulation #607

Open thiagolaitz opened 3 months ago

thiagolaitz commented 3 months ago

Hello, does MaxText support gradient accumulation or microbatches like those in the T5X repository? I didn't find a parameter for this in base.yml; maybe I just missed it? Thank you!

rwitten commented 3 months ago

We don't support that out of the box. We've found that tuning LR to be smaller is a better approach.

What is your use case?

thiagolaitz commented 3 months ago

I'm training bigger models than before, so I can't use the same batch size on the same TPU. Got any recommended ablation studies on using gradient accumulation versus lowering the LR? Also, if I skip gradient accumulation, should I just linearly reduce the LR based on the batch size? Thanks!
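For reference, here is a minimal sketch of what "linearly reduce the LR based on the batch size" would mean in practice. This is the general linear-scaling heuristic, not something maxtext prescribes, and the function and variable names below are illustrative only:

def scale_learning_rate(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    # Linear-scaling heuristic: shrink the LR by the same factor as the global batch size.
    return base_lr * new_batch_size / base_batch_size

# Example: halving the global batch size from 1024 to 512 halves the LR.
scale_learning_rate(3e-4, 1024, 512)  # 1.5e-4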

rodrigo-f-nogueira commented 2 months ago

+1 Adding another use case: because TPU availability varies, we encounter situations where we initially train a model on a v4-128 TPU but later need to replicate the experiment on a v4-64 TPU, which has less memory. Thus, we must use gradient accumulation to maintain consistency in the results.
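For reference, keeping the effective global batch size fixed on a smaller slice comes down to a simple ratio; the sketch below uses illustrative names, not maxtext config keys:

def accumulation_steps(target_global_batch: int, per_device_batch: int, num_devices: int) -> int:
    # Number of microbatches to accumulate so that
    # per_device_batch * num_devices * steps == target_global_batch.
    effective_per_step = per_device_batch * num_devices
    assert target_global_batch % effective_per_step == 0, "target must be a multiple of the per-step batch"
    return target_global_batch // effective_per_step

# Example: half the devices at the same per-device batch size => accumulate over 2 steps.
accumulation_steps(1024, 16, 32)  # 2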

hxssgaa commented 2 months ago

Simply adding the following code after the optimizer is created in optimizers.py adds support for gradient accumulation:

if config.accumulate_gradient_steps > 1:
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
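For context, optax.MultiSteps wraps a base optimizer, buffers (and by default averages) incoming gradients, and emits all-zero updates until the k-th call to update, at which point it applies the accumulated gradient. Below is a minimal, self-contained sketch of that behavior; the toy model and shapes are illustrative, not maxtext code, and accumulate_gradient_steps would be a new config entry, since base.yml did not expose one at the time of this thread:

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((2,))}
base_opt = optax.adamw(1e-3)
# Apply a real parameter update only every 2 microbatches.
opt = optax.MultiSteps(base_opt, every_k_schedule=2)
opt_state = opt.init(params)

def loss_fn(p, x):
    return jnp.sum((p["w"] * x) ** 2)

for x in [jnp.array([1.0, 2.0])] * 4:
    grads = jax.grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grads, opt_state, params)
    # updates are all zeros on accumulation steps and non-zero on every 2nd call.
    params = optax.apply_updates(params, updates)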