XinDongol opened this issue 2 months ago
@XinDongol do you mean microbatching or pipeline parallel?
@awgu - is there a context manager or similar option in fsdp2 that would support gradient accumulation and thus enable this in titan? I know we talked about this for HSDP but not sure about generic FSDP2.
I am guessing this is asking for normal microbatching. There are similar APIs for FSDP2 that can control communication during gradient accumulation.
We migrated the `no_sync()` context manager to `module.set_requires_gradient_sync(bool)`, so that it can simply be placed at the top of the training loop as `module.set_requires_gradient_sync(is_last_microbatch)`. Note, however, that for memory-constrained cases we typically prefer to proceed as normal and reduce-scatter every microbatch.
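As a torch-free sketch of how the deferred sync behaves (a mock class for illustration only, not the real FSDP2 implementation): communication fires only on the microbatch where gradient sync is enabled.

```python
class MockFSDPModule:
    """Toy stand-in for an FSDP2-wrapped module (illustration only)."""

    def __init__(self):
        self.requires_gradient_sync = True  # default: sync every backward
        self.sync_count = 0                 # how many reduce-scatters ran

    def set_requires_gradient_sync(self, flag):
        self.requires_gradient_sync = flag

    def backward(self):
        # A real FSDP2 module reduce-scatters gradients in backward
        # only when gradient sync is enabled.
        if self.requires_gradient_sync:
            self.sync_count += 1


num_microbatches = 4
module = MockFSDPModule()
for microbatch_idx in range(num_microbatches):
    # Sync (reduce-scatter) only on the last microbatch.
    module.set_requires_gradient_sync(microbatch_idx == num_microbatches - 1)
    module.backward()

print(module.sync_count)  # 1: communication happened once, on the last microbatch
```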
Thanks for updating. @wanchaol Yes, I am talking about microbatching.
@awgu Is it sufficient to make the following change? Thanks. From (current):
```python
with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()
```
to:
```python
for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    model.set_requires_gradient_sync(microbatch_idx == (microbatch - 1))
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```
@XinDongol I think that is sufficient.
If you want to avoid reduce-scatter in backward, then what you have is right. Note however that this will mean that gradients are left as unsharded through backward, which may use too much memory depending on the workload.
If you want to still reduce-scatter in backward, you can simply remove that `model.set_requires_gradient_sync` line (effectively leaving it as the default of `True`).
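As a torch-free sanity check of the `loss_fn(pred, labels) / microbatch` scaling in the loop above (a toy linear model with hand-computed gradients, not titan code): accumulating each microbatch's gradient of `loss / microbatch` reproduces the gradient of the mean loss over the full batch, assuming equal-sized microbatches and a mean-reduced loss.

```python
# Toy check: for y_hat = w * x with squared-error loss, summing
# per-microbatch gradients of (microbatch mean loss) / num_micro
# equals the gradient of the mean loss over the whole batch.
def grad_mse(w, xs, ys):
    # d/dw of mean((w*x - y)**2) over the given samples
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
num_micro = 2
mb_size = len(xs) // num_micro

full_grad = grad_mse(w, xs, ys)  # gradient of the full-batch mean loss

accum = 0.0                      # running gradient accumulator
for i in range(num_micro):
    mb_x = xs[i * mb_size:(i + 1) * mb_size]
    mb_y = ys[i * mb_size:(i + 1) * mb_size]
    accum += grad_mse(w, mb_x, mb_y) / num_micro  # like `loss / microbatch`

print(abs(full_grad - accum) < 1e-12)  # True: same gradient either way
```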
Thanks @awgu @XinDongol. Very helpful discussion.
If `model.set_requires_gradient_sync` is always set to `True`, is that equivalent to just doing normal training without gradient accumulation, like below?
```python
for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```
Is there a way to accumulate the gradient by keeping a running sum of the loss, and only call `loss.backward()` after finishing all the microbatches?
> Is there a way to accumulate the gradient by keeping a running sum of the loss, and only call `loss.backward()` after finishing all the microbatches?
What is the advantage of doing this? When you run a microbatch forward, the autograd graph associated with it (e.g. its activations) is kept alive until you run the corresponding microbatch backward. If you run all of your microbatch forwards before any microbatch backward, then your memory cost will be similar to running the entire batch in one forward.
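This memory argument can be sketched with a toy activation counter (plain Python, no torch; the counter is an illustration, not how autograd tracks memory): summing losses and running one backward at the end holds every microbatch's activations at once.

```python
# Toy illustration: activations saved during a microbatch forward stay
# alive until that microbatch's backward runs and frees them.
class ActivationTracker:
    def __init__(self):
        self.live = 0   # activations currently held for backward
        self.peak = 0   # peak live activations

    def forward(self, n=1):
        self.live += n  # forward saves activations for backward
        self.peak = max(self.peak, self.live)

    def backward(self):
        self.live = 0   # backward frees the saved activations


num_microbatches = 4

# Pattern A: backward after every microbatch forward (the usual loop).
a = ActivationTracker()
for _ in range(num_microbatches):
    a.forward()
    a.backward()

# Pattern B: sum the losses across all forwards, one backward at the end.
b = ActivationTracker()
for _ in range(num_microbatches):
    b.forward()
b.backward()

print(a.peak, b.peak)  # 1 4: pattern B pays full-batch activation memory
```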
Gradient accumulation (micro-stepping) can be very useful when we want a large batch size but have a limited number of GPUs.