pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[Feature] Add gradient accumulation #292

Open XinDongol opened 2 months ago

XinDongol commented 2 months ago

Gradient accumulation (micro steps) could be very useful when we want a large batch size but only have a limited number of GPUs.

wanchaol commented 2 months ago

@XinDongol do you mean microbatching or pipeline parallel?

lessw2020 commented 2 months ago

@awgu - is there a context manager or similar option in FSDP2 that would support gradient accumulation and thus enable this in titan? I know we talked about this for HSDP, but I'm not sure about generic FSDP2.

awgu commented 2 months ago

I am guessing this is asking for normal microbatching. There are similar APIs for FSDP2 that can control communication during gradient accumulation.

We migrated the no_sync() context manager to module.set_requires_gradient_sync(bool) so that it can simply be placed at the top of the training loop as module.set_requires_gradient_sync(is_last_microbatch). Note, however, that in memory-constrained cases we typically prefer to just proceed as normal and reduce-scatter every microbatch.
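
A rough before/after sketch of that migration (model, input_ids, labels, loss_fn, and is_last_microbatch are illustrative placeholders rather than torchtitan code, and the "before" half assumes an FSDP1-wrapped module):

import contextlib

# Before (FSDP1 / DDP style): skip gradient reduction on non-final
# microbatches by entering the no_sync() context manager.
with model.no_sync() if not is_last_microbatch else contextlib.nullcontext():
    loss = loss_fn(model(input_ids), labels)
    loss.backward()

# After (FSDP2 style): set the flag once at the top of the microbatch loop;
# reduce-scatter then only runs during the last microbatch's backward.
model.set_requires_gradient_sync(is_last_microbatch)
loss = loss_fn(model(input_ids), labels)
loss.backward()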

XinDongol commented 2 months ago

Thanks for the update. @wanchaol Yes, I am talking about microbatching.

https://github.com/pytorch/torchtitan/blob/58b11693507bc16e7df4618455ebe66e8094f71d/train.py#L291-L294

@awgu, is it sufficient to make the following change? Thanks. From (current):

with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()

to:

for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    # Only reduce-scatter gradients on the last microbatch
    model.set_requires_gradient_sync(microbatch_idx == microbatch - 1)
    with loss_parallel_ctx():
        pred = model(input_ids)
        # Scale the loss so the accumulated gradient matches the full batch
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()

awgu commented 2 months ago

@XinDongol I think that is sufficient.

If you want to avoid the reduce-scatter in backward, then what you have is right. Note, however, that this means gradients are left unsharded through backward, which may use too much memory depending on the workload.

If you want to still reduce-scatter in backward, you can simply remove that model.set_requires_gradient_sync line (effectively leaving it as the default of True).

dreasysnail commented 1 week ago

> @XinDongol I think that is sufficient.
>
> If you want to avoid the reduce-scatter in backward, then what you have is right. Note, however, that this means gradients are left unsharded through backward, which may use too much memory depending on the workload.
>
> If you want to still reduce-scatter in backward, you can simply remove that model.set_requires_gradient_sync line (effectively leaving it as the default of True).

Thanks @awgu @XinDongol. Very helpful discussion. If model.set_requires_gradient_sync is always set to True, is that equivalent to just doing normal training without gradient accumulation? Like below?

for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()

Is there a way to accumulate the gradient by keeping a running sum of the loss, and only calling loss.backward() after finishing all the microbatches?

awgu commented 1 week ago

> Is there a way to accumulate the gradient by keeping a running sum of the loss, and only calling loss.backward() after finishing all the microbatches?

What is the advantage of doing this? When you run a microbatch forward, the autograd graph associated with it (e.g. its activations) is kept alive until you run the corresponding microbatch backward. If you run all of your microbatch forwards before any microbatch backward, then your peak memory cost will be similar to running the entire batch in one forward.
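
For illustration, a rough sketch of that running-sum variant (variable names like num_microbatches and data_iterator are placeholders); the comments restate why it does not reduce peak memory:

num_microbatches = 4  # illustrative value

total_loss = 0.0
for microbatch_idx in range(num_microbatches):
    input_ids, labels = next(data_iterator)
    pred = model(input_ids)
    # Each forward's autograd graph (and the activations it saves) stays
    # alive until backward runs on it, so memory grows with every microbatch.
    total_loss = total_loss + loss_fn(pred, labels) / num_microbatches

# A single backward at the end frees all of the graphs at once, but by this
# point peak activation memory is already comparable to running the whole
# batch in one forward, so this variant does not save memory over the
# per-microbatch backward loop above.
total_loss.backward()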