pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Gradient accumulation is not efficiently implemented for distributed recipes #1275

Open physicsrob opened 2 months ago

physicsrob commented 2 months ago

For distributed recipes, such as full_finetune_distributed, the gradients end up getting synchronized after each backward() pass instead of only once before the optimizer step. This results in significant unnecessary communication overhead.

I noticed this when porting a training script from Hugging Face Accelerate to one based on the full_finetune_distributed recipe. For my use case (fine-tuning Llama3 on 8 nodes with 4x A100 80GB each), training time more than doubled.

Digging in, I found this to be the cause. I've fixed this in my recipe, but it seems like the fix should be applied to all distributed recipes.

I'd be happy to contribute a PR if it's agreed that it should be fixed.

For reference, the fix is roughly to wrap the forward/backward pass with:

```python
with nullcontext() if do_sync else self._model.no_sync():
    ...  # forward and backward for this micro-batch
```

where do_sync is True only for the last micro-batch of each gradient accumulation cycle.

(There's an additional, fairly minor bug where, on the first step, the optimizer step runs after just a single forward/backward instead of after a full accumulation cycle.)
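A minimal sketch of that pattern, assuming `model` is wrapped in DDP or FSDP1 (both expose `no_sync()`). `train_steps`, `forward_backward`, and `optimizer_step` are hypothetical names standing in for the recipe's loss/backward and optimizer/zero_grad calls, not torchtune APIs. Checking `(idx + 1) % grad_accum_steps == 0` also keeps the first optimizer step from firing after a single micro-batch.

```python
from contextlib import nullcontext


def is_last_microbatch(idx: int, grad_accum_steps: int) -> bool:
    # idx is 0-based, so the first optimizer step happens after grad_accum_steps
    # micro-batches rather than after the very first one.
    return (idx + 1) % grad_accum_steps == 0


def sync_context(model, do_sync: bool):
    # Inside model.no_sync() (DDP / FSDP1), backward() accumulates gradients
    # locally and skips the gradient communication.
    return nullcontext() if do_sync else model.no_sync()


def train_steps(model, batches, grad_accum_steps, forward_backward, optimizer_step):
    # forward_backward(model, batch) and optimizer_step() are hypothetical hooks.
    for idx, batch in enumerate(batches):
        do_sync = is_last_microbatch(idx, grad_accum_steps)
        # DDP needs the forward pass inside no_sync() too, otherwise gradients
        # are still synchronized, so wrap forward and backward together.
        with sync_context(model, do_sync):
            forward_backward(model, batch)
        if do_sync:
            optimizer_step()  # gradients were synchronized once for this step
```

With `grad_accum_steps=4`, three of every four micro-batches run under `no_sync()`, so the gradient collective runs once per optimizer step instead of four times.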

joecummings commented 2 months ago

cc @weifengpy

This seems pretty straightforward, thanks for catching this @physicsrob. We'd love to accept a PR to fix this :)

weifengpy commented 2 months ago

Looking forward to the PR. I'm actually curious: which line of code triggers the sync?

kartikayk commented 2 months ago

Also cc: @awgu since we had a chat about gradient accumulation and FSDP in the early days of torchtune

awgu commented 2 months ago

For FSDP, there are two ways to accumulate gradients:

  1. Accumulate unsharded gradients (model.no_sync context in FSDP1, model.set_requires_gradient_sync(is_last_microbatch) in FSDP2)
  2. Accumulate sharded gradients

We should differentiate between training throughput and memory usage (both of which could be referred to by "efficiency").

The two options to accumulate gradients pose a direct tradeoff between communication time and memory usage.

Note that extra communication time mainly translates to lower throughput when it cannot be fully overlapped, which depends on your inter-node bandwidth.

For Llama3-8B, the unsharded gradients take ~8B elements per rank (roughly 16 GB in bf16 or 32 GB in fp32). Whether option 1 or option 2 makes sense depends on how much memory you can afford to trade off and whether you want to accumulate gradients in fp32 or bf16.
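A back-of-the-envelope sketch of that tradeoff. It counts only the accumulated gradient buffers (ignoring parameters, optimizer state, activations, and the transient per-unit unsharded gradients FSDP materializes during backward). The 32-rank world size mirrors the 8 nodes x 4 GPUs setup from the original report; 8 accumulation steps is an arbitrary illustrative choice, and `grad_accum_costs` is a hypothetical helper.

```python
def grad_accum_costs(num_params: int, bytes_per_elem: int, world_size: int, grad_accum_steps: int):
    """Per-rank accumulated-gradient memory and full-model reduce-scatters per optimizer step."""
    option1_unsharded = {
        # no_sync() / set_requires_gradient_sync(False): every rank holds full-size gradients
        "grad_bytes_per_rank": num_params * bytes_per_elem,
        # a single reduce-scatter, on the last micro-batch only
        "reduce_scatters_per_step": 1,
    }
    option2_sharded = {
        # reduce-scatter after every backward: each rank keeps only its 1/world_size shard
        "grad_bytes_per_rank": num_params * bytes_per_elem / world_size,
        "reduce_scatters_per_step": grad_accum_steps,
    }
    return option1_unsharded, option2_sharded


for dtype, nbytes in [("bf16", 2), ("fp32", 4)]:
    opt1, opt2 = grad_accum_costs(8_000_000_000, nbytes, world_size=32, grad_accum_steps=8)
    print(
        f"{dtype}: option 1 {opt1['grad_bytes_per_rank'] / 1e9:.1f} GB, "
        f"{opt1['reduce_scatters_per_step']} reduce-scatter/step; "
        f"option 2 {opt2['grad_bytes_per_rank'] / 1e9:.1f} GB, "
        f"{opt2['reduce_scatters_per_step']} reduce-scatters/step"
    )
```

In bf16 this comes to 16 GB vs 0.5 GB of gradients per rank (32 GB vs 1 GB in fp32), paid for with 1 vs 8 reduce-scatters per optimizer step.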

Note that with FSDP2, you can play around with this at a finer granularity. For example, if you are "wrapping" / applying fully_shard to every transformer block, then you can selectively disable reduce-scatter for every other (or every k-th) transformer block, since set_requires_gradient_sync(bool) is a method on FSDPModule (which each transformer block becomes when you call fully_shard on it). This can help the remaining, fewer reduce-scatters overlap with compute.
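A rough sketch of that selective scheme, assuming `root` is the model after calling `fully_shard` on each transformer block and then on the whole model, and `blocks` is the list of those blocks (e.g. a hypothetical `model.layers`). It relies only on FSDP2's `set_requires_gradient_sync(requires_gradient_sync, *, recurse=True)`; `configure_selective_sync` and `k` are illustrative names, not a torchtune or PyTorch API.

```python
def configure_selective_sync(root, blocks, is_last_microbatch: bool, k: int = 2) -> None:
    """Call before each micro-batch's forward/backward (FSDP2 sketch)."""
    # Default for the whole module tree: reduce-scatter only on the last
    # micro-batch, i.e. accumulate unsharded gradients otherwise.
    root.set_requires_gradient_sync(is_last_microbatch, recurse=True)
    if is_last_microbatch:
        return  # every module must reduce-scatter before the optimizer step
    # On other micro-batches, keep reduce-scatter on for every k-th block, so those
    # blocks accumulate sharded gradients (less memory) while the rest skip it.
    for i, block in enumerate(blocks):
        block.set_requires_gradient_sync(i % k == 0, recurse=False)
```

With `k=2`, blocks 0, 2, 4, ... still reduce-scatter on every micro-batch while the others accumulate unsharded gradients; a larger `k` trades more memory for fewer collectives.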

ScottHoang commented 2 months ago

So in a no_sync context, are gradients accumulated in FP32?

awgu commented 2 months ago

I think in FSDP1, they are accumulated in the dtype that the gradients were computed in (e.g. bf16). In FSDP2, if you specify MixedPrecisionPolicy(reduce_dtype=torch.float32), then it will have extra logic to accumulate the gradients in fp32; if you did not specify that higher precision reduce_dtype, then it will similarly accumulate in the dtype that the gradients were computed in.
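A toy, pure-Python illustration (not FSDP code) of why the accumulation dtype matters. It simulates bf16 by round-to-nearest-even on the upper 16 bits of the fp32 bit pattern and adds small micro-batch contributions into a larger running sum; the gradient values are made up. In FSDP2, the fp32 path corresponds to specifying `MixedPrecisionPolicy(reduce_dtype=torch.float32)` as described above.

```python
import struct


def to_fp32(x: float) -> float:
    # Round a Python float (fp64) to the nearest fp32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    # Round to bfloat16: round-to-nearest-even on the upper 16 bits of the fp32
    # bit pattern (NaN/inf handling omitted for brevity).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def accumulate(grads, round_fn):
    # The running gradient buffer is stored in round_fn's dtype after every add.
    total = 0.0
    for g in grads:
        total = round_fn(total + round_fn(g))
    return total


# One large contribution followed by eight small ones for a single parameter.
micro_grads = [1.0] + [0.001] * 8
print("bf16 accumulation:", accumulate(micro_grads, to_bf16))
print("fp32 accumulation:", accumulate(micro_grads, to_fp32))
```

Near 1.0 adjacent bf16 values are 2**-7 (about 0.0078) apart, so each 0.001 increment rounds away and the bf16 buffer stays at exactly 1.0, while the fp32 buffer reaches about 1.008.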

ScottHoang commented 2 months ago

Extending on this (and maybe unrelated to the overall topic): in the current FSDP1 setup, parameters are sharded across all nodes in a multi-node run (ZeRO-3 style). Is there a way to limit the sharding to within a node?

awgu commented 2 months ago

You should be able to pass device_mesh as a 2D device mesh to enable HSDP. (You could also pass in a 2-tuple of process groups, but I think the checkpointing support is better for the device mesh path.)

You can create the device mesh via something like:

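```python
import torch
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

global_world_size = torch.distributed.get_world_size()  # global world size, assumes dist is initialized
intra_node_size = torch.cuda.device_count()  # e.g. 8
device_mesh = init_device_mesh("cuda", (global_world_size // intra_node_size, intra_node_size))  # (replicate, shard)

# FSDP1 only accepts a 2D device mesh together with a hybrid sharding strategy
fsdp_model = FSDP(model, ..., sharding_strategy=ShardingStrategy.HYBRID_SHARD, device_mesh=device_mesh)
```

ScottHoang commented 2 months ago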

So, for a small model (< 13B) that we want to scale across multiple nodes to increase throughput, is it better to use HSDP with a device_mesh, i.e. FSDP within each node and DDP-style replication across nodes?

awgu commented 2 months ago

It depends on your inter-node bandwidth. If your inter-node bandwidth is fast, FSDP is probably still better, especially if your model is compute-dense like a transformer.

The overall workflow I would follow is to capture a profiler trace, check whether the communications are overlapping, and decide from there. If you can still overlap the FSDP collectives in the multi-node setup, then you probably prefer FSDP, because it saves more memory that you can reinvest into a larger batch size or less activation checkpointing.

If your inter-node bandwidth is slow, though, there may be a large discrete jump in communication time when going from single-node to multi-node, in which case HSDP can help because only the gradient all-reduce crosses nodes.
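To see which collectives cross nodes, here is a small pure-Python sketch of the rank layout behind a `(replicate, shard)` mesh like the one built with `init_device_mesh` above. It assumes ranks are numbered node by node (torchrun's default) and that the mesh is laid out row-major over global ranks; `hsdp_groups` is a hypothetical helper, not a PyTorch API.

```python
def hsdp_groups(world_size: int, gpus_per_node: int):
    """Return (shard_groups, replicate_groups) for a (num_nodes, gpus_per_node) mesh."""
    if world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Rows of the mesh, one per node: FSDP all-gather / reduce-scatter stay inside these.
    shard_groups = [
        list(range(node * gpus_per_node, (node + 1) * gpus_per_node)) for node in range(num_nodes)
    ]
    # Columns of the mesh, the same local GPU index on every node: only the gradient
    # all-reduce between replicas runs over these inter-node groups.
    replicate_groups = [list(range(local, world_size, gpus_per_node)) for local in range(gpus_per_node)]
    return shard_groups, replicate_groups
```

For the 8 nodes x 4 GPUs setup from the original report, `hsdp_groups(32, 4)` gives 8 intra-node shard groups of 4 ranks and 4 inter-node replicate groups of 8 ranks.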

ScottHoang commented 2 months ago

This perfectly solved my problem! Thank you!

ScottHoang commented 2 months ago

@awgu Actually, one last question: in hybrid shard mode with multiple nodes, does "sync_module_states" still broadcast rank 0's params to ranks on other nodes?

awgu commented 2 months ago

@ScottHoang yes, it will broadcast from global rank 0 to all ranks (including both intra and inter-node process groups): https://github.com/pytorch/pytorch/blob/afaa5fcecb07472a8805902074f4611dc5798f76/torch/distributed/fsdp/_init_utils.py#L632-L635