young-geng / EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, finetuning, evaluating, and serving LLMs in JAX/Flax.
Apache License 2.0

How does `accumulate_gradient_steps` work? #108

Open VictorSanh opened 4 months ago

VictorSanh commented 4 months ago

Hi,

I am not sure I understand the logic behind `accumulate_gradient_steps`. I have these three configurations:

batch_size=1, accumulate_gradient_steps=1 -> blue
batch_size=2, accumulate_gradient_steps=1 -> red
batch_size=2, accumulate_gradient_steps=2 -> green
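(For reference, if accumulation works the way I describe below, the effective batch size per optimizer update would be batch_size × accumulate_gradient_steps, i.e. 1 for blue, 2 for red, and 4 for green.)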

My initial understanding of gradient accumulation is that the training loop performs `accumulate_gradient_steps` forward + backward passes, accumulating the gradients, and only then the optimizer takes a step.

1/ I don't see where that logic is handled in llama_train.py. It looks like in the `train_step` method there is no counter for `accumulate_gradient_steps`, and an optimizer step is taken after every forward pass?
2/ The logging is confusing: I would have expected the red and blue lines to overlap, not the blue and green.

[Screenshot from 2024-03-08 comparing the logged curves for the three configurations]
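Regarding 1/: this is roughly the pattern I had in mind, where gradient accumulation is handled entirely inside the optimizer (e.g. with `optax.MultiSteps`) so that `train_step` itself needs no explicit counter. This is only a minimal sketch under that assumption, not EasyLM's actual code; the optimizer choice, shapes, and hyperparameters here are made up:

```python
import jax
import jax.numpy as jnp
import optax

accumulate_gradient_steps = 2  # hypothetical value matching the green run

base_optimizer = optax.adamw(learning_rate=1e-4)
# MultiSteps keeps an internal mini-step counter: it accumulates gradients and only
# applies a real parameter update every `accumulate_gradient_steps` calls to update().
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=accumulate_gradient_steps)

params = {"w": jnp.zeros((4,))}
opt_state = optimizer.init(params)

def loss_fn(params, batch):
    # Toy quadratic loss on a (micro-)batch of shape (n, 4).
    return jnp.mean((batch @ params["w"]) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # On non-boundary calls the returned updates are all zeros, so params stay unchanged.
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

If that is how it is wired up, calling `train_step` on every micro-batch would be correct even though no accumulation counter is visible in the training loop.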

Is it possible that `step` counts forward+backward operations, and not full cycles of (forward+backward) × accumulate_gradient_steps followed by an optimizer step?
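If EasyLM does use an `optax.MultiSteps`-style wrapper, one way to check what `step` counts would be to inspect the wrapper's state, which tracks micro-batches and real optimizer updates separately. Continuing the hypothetical sketch above (`batches` is a made-up iterable of micro-batches):

```python
# Continuing the sketch above (same params, opt_state, optimizer, train_step).
batches = [jnp.ones((8, 4)) for _ in range(4)]  # hypothetical micro-batches
for step, batch in enumerate(batches):
    params, opt_state, loss = train_step(params, opt_state, batch)
    # mini_step counts forward+backward passes inside the current accumulation window;
    # gradient_step counts actual parameter updates (true optimizer steps).
    print(step, int(opt_state.mini_step), int(opt_state.gradient_step), float(loss))
```

If the logged `step` advances with every micro-batch rather than with `gradient_step`, that would explain why the blue and green curves line up per step while the red one does not.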