Hi! This solution does indeed make sense to me, let's start with a PR to accelerate and then the upstream to transformers? :)

Note: for the `TrainingArguments`, we need to add this to the Accelerator config class instead and handle the logic that way, as we are no longer adding more args to the `TrainingArguments` when dealing with accelerate and are instead handling it through the new config class.
@muellerzr thanks for looking at the issue. Understood, I will add the `gradient_accumulation_force_sync` arg to `AcceleratorConfig` instead. Will have an accelerate PR to review soon. :)
@muellerzr As discussed, I have first drafted an accelerate PR.

While fixing the tests, I noticed that one of the old tests, `test_gradient_accumulation_with_opt_and_scheduler`, was disabled for torch < 2.0. On further inspection the test was terribly broken (it was zeroing gradients before they were being checked).

In the PR I have raised, I have fixed the `test_gradient_accumulation_with_opt_and_scheduler` test somewhat, but in `check_model_parameters` I need to pass `rtol=1e-3` to `torch.allclose`, see here. For the other test, `test_gradient_accumulation`, the `rtol` setting was not needed (the error was much smaller). If you want, I can investigate more closely why.

Finally, I have yet to update the docs; if you have any pointers on which documentation I should focus, please let me know.
There seems to be a bug. If I set `sync_each_batch=True`, the optimizer updates the weights every batch, even if I set `gradient_accumulation_steps=4`.
Thanks for reporting. We do have unit tests, but maybe we overlooked something. To help me understand better, do you have a reproduction of what you are seeing?

Update: also, just to make sure, you are not using CPU offload with FSDP together with `sync_each_batch=True`? That combination does not support grad accum, see here.
Thanks for your clarification. Here is some part of my code:
```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

grad_accumulate_plugin = GradientAccumulationPlugin(
    num_steps=args.accumulate_grad_iters
)
accelerator = Accelerator(
    mixed_precision=args.mixed_precision,
    gradient_accumulation_plugin=grad_accumulate_plugin,
    log_with=["tensorboard"],
    project_config=project_config,
)

for elapse_iter, batch in enumerate(active_dataloader):
    with accelerator.accumulate(model):
        ret_dict, tb_dict = model(batch)
        loss = ret_dict["loss"].mean()
        loss_value = loss.item()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
```
I only use `Accelerator` in my code and do not use `Trainer` (that seems to be ok). This code works smoothly when `sync_each_batch=False`, but produces undesirable results when `sync_each_batch=True`, i.e., the optimizer updates on every step.

I think the `sync_each_batch` flag would be better placed here.
@Nightmare-n thanks for sharing the above code snippet. I see that you follow the gradient accum concept guide, but now I'm confused about what `Accelerator.accumulate` does, so let me clarify with the maintainers first.
@muellerzr the grad accum concept guide does say that one can remove the `total_batched_samples % args.gradient_accumulation_steps == 0` guard we typically use to prevent `optimizer.step` whilst in an accum batch. However, I do not see anything in the `Accelerator.accumulate` implementation that implements this guard. It only controls when we apply the `no_sync` context manager, which controls the frequency of gradient sync, but does not control when the optimizer steps. So did something change in the implementation? If what I said above is correct, then the concept guide is inaccurate, and I have an explanation for @Nightmare-n's observation.
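For concreteness, this is the kind of manual guard I am referring to (a rough sketch of the usual pattern, with a toy model and placeholder names, not code taken from the guide):

```python
import torch

# Manual gradient accumulation with the explicit optimizer-step guard.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4

for total_batched_samples in range(1, 17):
    batch = torch.randn(8, 4)
    loss = model(batch).pow(2).mean() / gradient_accumulation_steps
    loss.backward()
    # The guard in question: only step the optimizer at the end of an accumulation window.
    if total_batched_samples % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```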
@Nightmare-n were you trying with DDP or FSDP?
Yes, I have tried both DDP and FSDP. I use `AcceleratedOptimizer`, and its `step` function checks the `sync_gradients` flag to determine whether the model weights should be updated (look at here).
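Roughly, the gating looks like this (a simplified illustration of the behaviour described above, not the actual accelerate source; the class and attribute names are stand-ins):

```python
# The wrapped optimizer only applies an update when gradients are marked
# as synchronized; otherwise the call is a no-op and accumulation continues.
class AcceleratedOptimizerSketch:
    def __init__(self, optimizer, gradient_state):
        self.optimizer = optimizer
        self.gradient_state = gradient_state

    def step(self, closure=None):
        if self.gradient_state.sync_gradients:
            # End of an accumulation window (or a forced sync): update the weights.
            self.optimizer.step(closure)
        # Otherwise: skip the update and keep accumulating gradients.
```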
@muellerzr @Nightmare-n Oh no, my bad. I completely overlooked this. That means this PR, as @Nightmare-n said, is incorrect.
I drafted something quickly here https://github.com/huggingface/accelerate/pull/2790 but I haven't had time to test; let me try to find some time. @Nightmare-n suggested fixing it inside `no_sync`, but I thought that violates the naming of the function, hence I kept the fix in the same `accumulate` function. Any comments are welcome.
Feature request
We propose a feature to allow:

- `_do_sync` to take a `force` boolean flag, where `_do_sync(force=True)` forces a gradient sync.
- `Trainer`/`Accelerate` to appropriately pass the `force` flag if the user requests the gradients to sync during accumulation.

During the main `_inner_training_loop`, the `training_step` is run under a contextmanager created by `Accelerator.accumulate`.

If we inspect the contextmanager, we notice that `Accelerator.accumulate` will return the `no_sync` context whenever `self.sync_gradients == False`, i.e., while still inside an accumulation window.

On inspection, `_do_sync` sets `self.sync_gradients == True` only at the end of a gradient accumulation batch. NOTE: `Trainer` sets `sync_with_dataloader = False` and this cannot be changed, therefore the first clause will never execute.

Hence we propose to allow the user to force `_do_sync` to set `self.gradient_state._set_sync_gradients(True)`.
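A minimal sketch of the proposed change, based on our reading of accelerate's current `_do_sync` logic (the final implementation may differ):

```python
# Proposed: _do_sync gains a `force` argument that marks gradients for
# synchronization even in the middle of an accumulation window.
def _do_sync(self, force: bool = False):
    if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader:
        self.step = 0
        self.gradient_state._set_sync_gradients(True)
    else:
        self.step += 1
        self.gradient_state._set_sync_gradients(
            force or ((self.step % self.gradient_state.num_steps) == 0)
        )
```

Note that `sync_gradients` also gates `AcceleratedOptimizer.step`, so forcing it this way also makes the optimizer step on every batch, which is the behaviour reported earlier in this thread.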
Motivation
Not syncing gradients can have adverse effects in distributed training. As has been warned in `torch`, the `no_sync` context manager for FSDP will incur additional memory requirements:

Gradient accumulation in FSDP often results in OOM on large models with a moderate number of GPUs. This occurs because `Trainer` by default will activate `no_sync` when using gradient accumulation, effectively disabling gradient synchronization to reduce communication across shards. However, this results in high memory usage because parameters and gradients are not resharded. We propose a solution that avoids OOM by allowing the user to enable synchronization of parameters and gradients on all (or some) of the data batches when using gradient accumulation.

Setting: in the table below, we see Mixtral (47B parameters) and CodeLlama (34B parameters) will OOM on 8 A100-80GB when using gradient accumulation. However, when we enable synchronization (i.e. disable `no_sync`), there is no noticeable increase in GPU memory consumption when using gradient accumulation.

(Table comparing memory usage with and without `no_sync` omitted.)
Your contribution
We can help contribute PRs into `transformers` and `accelerate` to effect these changes. We propose to do the following in the `transformers` and `accelerate` packages.

Accelerate Repository:

- add additional control in `GradientAccumulationPlugin` (usage sketch below)
- pass `force` into `_do_sync`.
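For illustration, a user-facing opt-in on the accelerate side could look like the snippet below (`sync_each_batch` is the flag name used earlier in the thread; treat the exact name as illustrative):

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Accumulate over 4 steps, but force a gradient sync on every batch so that
# FSDP can reshard gradients between optimizer updates.
plugin = GradientAccumulationPlugin(num_steps=4, sync_each_batch=True)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)
```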
Transformers Repository:

- `TrainingArguments`:
- `create_accelerator_and_postprocess` to configure `GradientAccumulationPlugin` (see the sketch after this list):
- Documentation
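A hypothetical sketch of that `create_accelerator_and_postprocess` wiring, assuming the option is carried on the accelerator config as suggested above (`gradient_accumulation_kwargs` and the simplified method body are assumptions, not a confirmed transformers API):

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Simplified, hypothetical Trainer method: build the plugin from the
# accelerator config so no new TrainingArguments field is needed.
def create_accelerator_and_postprocess(self):
    grad_acc_kwargs = dict(
        getattr(self.args.accelerator_config, "gradient_accumulation_kwargs", None) or {}
    )
    grad_acc_kwargs.setdefault("num_steps", self.args.gradient_accumulation_steps)
    plugin = GradientAccumulationPlugin(**grad_acc_kwargs)  # e.g. sync_each_batch=True
    self.accelerator = Accelerator(gradient_accumulation_plugin=plugin)
```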