kibitzing opened this issue 1 month ago
Upon further investigation, I found that this issue is not solely related to the callbacks. When performing gradient accumulation, if the total number of batches in an epoch is not divisible by `gradient_accumulation_steps`, the accumulation window shifts.
Specifically, after the update triggered by the last (non-divisible) batch of an epoch, the gradient accumulation count should restart from 0. However, since `total_batched_samples` is not reset at each epoch, the divisibility check drifts out of alignment, which leads to this shifting issue.
For example, if `batches_per_epoch = 7` and `gradient_accumulation_steps = 4`, the updates proceed as follows:
Epoch 1
Epoch 2
Epoch 3
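The drift described above can be sketched with a small counter simulation (plain Python, not the actual Trainer code). It assumes, per the description above, that an update also fires on the last batch of each epoch, and it records how many sub-steps went into each update:

```python
# Minimal sketch of the counter behavior described above (not the actual
# Trainer loop): an update fires when total_batched_samples is divisible by
# gradient_accumulation_steps, and also on the last (remainder) batch of the
# epoch. batches_per_epoch = 7, gradient_accumulation_steps = 4.
batches_per_epoch = 7
gradient_accumulation_steps = 4

total_batched_samples = 0   # initialized once at training start, never reset
accumulated = 0             # sub-steps accumulated since the last update
updates = []                # (epoch, sub-steps that went into this update)

for epoch in range(1, 4):
    for step in range(batches_per_epoch):
        total_batched_samples += 1
        accumulated += 1
        is_last_batch = step + 1 == batches_per_epoch
        if total_batched_samples % gradient_accumulation_steps == 0 or is_last_batch:
            updates.append((epoch, accumulated))
            accumulated = 0

print(updates)
# Epoch 1 behaves as expected (an update after 4 sub-steps, then the
# remainder), but epoch 2 starts with an update after a single sub-step,
# because total_batched_samples (8) is already divisible by 4.
```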
cc @SunMarc and @muellerzr
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

`transformers` version: 4.39.0

Who can help?

No response

Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
Issue Analysis
The `on_step_begin` callback is invoked when `step` is divisible by `args.gradient_accumulation_steps` (i.e., `step % args.gradient_accumulation_steps == 0`). However, the `on_step_end` callback behaves differently. Its condition is as follows:

`total_batched_samples % args.gradient_accumulation_steps == 0 or is_last_step_and_steps_less_than_grad_acc`

Here, the `on_step_end` callback is triggered when `total_batched_samples` is divisible by `args.gradient_accumulation_steps`. It is important to note that `step` is reset at the beginning of each epoch, whereas `total_batched_samples` is initialized to 0 at the start of training and persists across all epochs until training ends.

Expected Behavior:
When `gradient_accumulation_steps = N`, there should be exactly N sub-steps between the `on_step_begin` and `on_step_end` callbacks. This ensures that gradients are accumulated correctly before an optimization step occurs. The only exception to this rule is the last step of an epoch or of the training run, where fewer sub-steps may exist.

Problematic Behavior Example
The issue arises when the number of steps per epoch is not divisible by `args.gradient_accumulation_steps`. For example, if `steps_per_epoch = 3` and `gradient_accumulation_steps = 2`, we observe the following behavior:

Epoch 1:
- `on_step_begin` called, `on_step_end` not called (expected behavior)
- `on_step_begin` not called, `on_step_end` called (expected behavior)
- `on_step_begin` called, `on_step_end` not called (expected behavior)

Epoch 2:

- `on_step_begin` called (0 % 2 == 0), `on_step_end` called (4 % 2 == 0) (incorrect: `on_step_end` is called after only one sub-step)
- `on_step_begin` not called (1 % 2 != 0), `on_step_end` not called (5 % 2 != 0) (incorrect)
- `on_step_begin` called (2 % 2 == 0), `on_step_end` called (6 % 2 == 0) (incorrect: `on_step_end` is called after only one sub-step)

Epoch 3:

- `on_step_begin` called, `on_step_end` not called (expected behavior)
- `on_step_begin` not called, `on_step_end` called (expected behavior)
- `on_step_begin` called, `on_step_end` not called (expected behavior)

Note:
`total_batched_samples` is incremented by 1 at the start of each step loop.

In this case, when the number of steps per epoch is not divisible by `gradient_accumulation_steps`, the callbacks only pair up correctly in the epochs where the running total happens to realign, leading to incorrect behavior during the other epochs.
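The callback schedule enumerated above can be reproduced with a short simulation of just the two divisibility checks (a sketch, not the actual Trainer loop):

```python
# Reproduce the callback schedule for steps_per_epoch = 3 and
# gradient_accumulation_steps = 2. on_step_begin fires when the per-epoch
# `step` is divisible by gradient_accumulation_steps; on_step_end fires when
# the training-wide `total_batched_samples` is.
steps_per_epoch = 3
gradient_accumulation_steps = 2

total_batched_samples = 0  # persists across epochs
schedule = []              # (epoch, begin_called, end_called) per step

for epoch in range(1, 4):
    for step in range(steps_per_epoch):  # `step` resets every epoch
        begin = step % gradient_accumulation_steps == 0
        total_batched_samples += 1       # incremented at the start of the loop
        end = total_batched_samples % gradient_accumulation_steps == 0
        schedule.append((epoch, begin, end))

for epoch, begin, end in schedule:
    print(epoch, "on_step_begin" if begin else "-", "on_step_end" if end else "-")
# In epoch 2, begin and end fire in the same step twice (after one sub-step
# each), matching the incorrect behavior listed above.
```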