huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Step shifting using total_batched_samples for gradient_accumulation_steps counting #33671

Open kibitzing opened 1 month ago

kibitzing commented 1 month ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

Issue Analysis

The on_step_begin callback is invoked when the step is divisible by args.gradient_accumulation_steps (i.e., step % args.gradient_accumulation_steps == 0). However, the on_step_end callback behaves differently: its condition is total_batched_samples % args.gradient_accumulation_steps == 0 or is_last_step_and_steps_less_than_grad_acc.

Here, the on_step_end callback is triggered when total_batched_samples is divisible by args.gradient_accumulation_steps. It’s important to note that step is reset at the beginning of each epoch, whereas total_batched_samples is initialized to 0 at the start of training and persists across all epochs until training ends.

Expected Behavior:

When gradient_accumulation_steps = N, there should be exactly N sub-steps between the on_step_begin and on_step_end callbacks, so that gradients are accumulated over the full window before each optimization step. The only exception is the last step of an epoch or of the training run, where fewer sub-steps may remain.
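One way to check this invariant during a real run is a small diagnostic callback. The sketch below is not part of the report; it only assumes the standard TrainerCallback events (on_step_begin, on_substep_end, on_step_end) and counts how many micro-batches fall between a begin/end pair:

```python
from transformers import TrainerCallback


class AccumulationAuditCallback(TrainerCallback):
    """Counts micro-batches between on_step_begin and on_step_end and flags
    windows whose size differs from gradient_accumulation_steps."""

    def __init__(self):
        self.window = 0

    def on_step_begin(self, args, state, control, **kwargs):
        # A new accumulation window is expected to start here.
        self.window = 0

    def on_substep_end(self, args, state, control, **kwargs):
        # Fired after each micro-batch that did not trigger an optimizer step.
        self.window += 1

    def on_step_end(self, args, state, control, **kwargs):
        # The micro-batch that triggered the optimizer step also belongs to the window.
        self.window += 1
        if self.window != args.gradient_accumulation_steps:
            print(
                f"global_step {state.global_step}: window of {self.window} "
                f"micro-batches instead of {args.gradient_accumulation_steps}"
            )
```

Registering it via Trainer(..., callbacks=[AccumulationAuditCallback()]) should stay silent when begin/end are paired correctly and print a line for every shifted window; the truncated final window of an epoch is the one expected exception.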

Problematic Behavior Example

The issue arises when the number of steps in an epoch is not divisible by args.gradient_accumulation_steps, so that total_batched_samples no longer lines up with the start of the next epoch. For example, with steps_per_epoch = 3 and gradient_accumulation_steps = 2, we observe the following behavior:

Epoch 1 (step = 0, 1, 2; total_batched_samples = 1, 2, 3): on_step_begin fires at steps 0 and 2, but on_step_end fires only at step 1 (total_batched_samples = 2), so the on_step_begin at step 2 is left without a matching on_step_end in this epoch.

Epoch 2 (step = 0, 1, 2; total_batched_samples = 4, 5, 6): on_step_begin fires at steps 0 and 2, and on_step_end fires on those same steps 0 and 2, i.e., immediately after the corresponding on_step_begin with only one sub-step in between.

Epoch 3 (step = 0, 1, 2; total_batched_samples = 7, 8, 9): the pattern of epoch 1 repeats, with on_step_begin at steps 0 and 2 and on_step_end only at step 1 (total_batched_samples = 8).

Note: total_batched_samples is incremented by 1 at the start of each step loop.

In this case, when the number of steps per epoch is not divisible by gradient_accumulation_steps, the two callbacks only pair up correctly at intervals, and in the other epochs on_step_end fires out of sync with on_step_begin.
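The trace above can be reproduced with a few lines of plain Python that mirror only the counting described in this report (a sketch, not the Trainer source; the is_last_step_and_steps_less_than_grad_acc branch is left out since, going by its name, it only matters when an epoch has fewer steps than gradient_accumulation_steps):

```python
# Standalone simulation of the callback-firing conditions described above.
steps_per_epoch = 3
gradient_accumulation_steps = 2
num_epochs = 3

total_batched_samples = 0  # initialized once, never reset across epochs
for epoch in range(1, num_epochs + 1):
    events = []
    for step in range(steps_per_epoch):  # `step` restarts at 0 every epoch
        total_batched_samples += 1  # incremented at the start of the step loop
        if step % gradient_accumulation_steps == 0:
            events.append(f"begin@step {step}")
        # ... forward/backward for this micro-batch would run here ...
        if total_batched_samples % gradient_accumulation_steps == 0:
            events.append(f"end@step {step} (tbs={total_batched_samples})")
    print(f"Epoch {epoch}: " + ", ".join(events))

# Output:
# Epoch 1: begin@step 0, end@step 1 (tbs=2), begin@step 2
# Epoch 2: begin@step 0, end@step 0 (tbs=4), begin@step 2, end@step 2 (tbs=6)
# Epoch 3: begin@step 0, end@step 1 (tbs=8), begin@step 2
```

The output makes the mismatch visible: in epoch 2, on_step_end fires on the same sub-step as on_step_begin, while in epochs 1 and 3 an on_step_begin is left dangling at the end of the epoch.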

kibitzing commented 1 month ago

Upon further investigation, I found that this issue is not limited to the callbacks. When the total number of batches in an epoch is not divisible by gradient_accumulation_steps, the gradient-accumulation windows used for the optimizer updates shift as well.

Specifically, after the update that consumes the last, non-divisible batch of an epoch, accumulation should start counting again from 0. However, because total_batched_samples is not reset at each epoch, the update boundaries shift instead.

For example, with batches_per_epoch = 7 and gradient_accumulation_steps = 4, the updates proceed as follows: in epoch 1 the optimizer steps after batch 4 (total_batched_samples = 4) and again with the 3 leftover batches at the end of the epoch, but epoch 2 then begins at total_batched_samples = 8, so its first update fires after accumulating only a single batch, and the window sizes keep drifting in later epochs.
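Below is a minimal simulation of that schedule (again only the counting, not the Trainer source), under the assumption stated above that an update also fires with whatever has accumulated when the epoch's dataloader ends:

```python
# Standalone simulation of the update schedule described above. It assumes, as the
# comment describes, that an update also fires with the leftover batches at the end
# of each epoch; everything else follows the total_batched_samples condition.
batches_per_epoch = 7
gradient_accumulation_steps = 4
num_epochs = 3

total_batched_samples = 0
window = 0  # micro-batches accumulated since the last optimizer update
for epoch in range(1, num_epochs + 1):
    for step in range(batches_per_epoch):
        total_batched_samples += 1
        window += 1
        is_last_batch_of_epoch = step == batches_per_epoch - 1
        if total_batched_samples % gradient_accumulation_steps == 0 or is_last_batch_of_epoch:
            print(
                f"epoch {epoch}, step {step}: optimizer update, "
                f"window size = {window} (total_batched_samples={total_batched_samples})"
            )
            window = 0

# Output:
# epoch 1, step 3: optimizer update, window size = 4 (total_batched_samples=4)
# epoch 1, step 6: optimizer update, window size = 3 (total_batched_samples=7)
# epoch 2, step 0: optimizer update, window size = 1 (total_batched_samples=8)
# epoch 2, step 4: optimizer update, window size = 4 (total_batched_samples=12)
# epoch 2, step 6: optimizer update, window size = 2 (total_batched_samples=14)
# epoch 3, step 1: optimizer update, window size = 2 (total_batched_samples=16)
# epoch 3, step 5: optimizer update, window size = 4 (total_batched_samples=20)
# epoch 3, step 6: optimizer update, window size = 1 (total_batched_samples=21)
```

If the accumulation count were keyed on the per-epoch step (or total_batched_samples were reset at each epoch boundary), every window except the truncated one at the end of an epoch would contain exactly gradient_accumulation_steps micro-batches.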

LysandreJik commented 1 month ago

cc @SunMarc and @muellerzr

SunMarc commented 1 month ago

Thanks for this clear report @kibitzing! I left a comment on your PR.

github-actions[bot] commented 2 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.