Open physicsrob opened 2 months ago
cc @weifengpy
This seems pretty straightforward, thanks for catching this @physicsrob. We'd love to accept a PR to fix this :)
Looking forward to the PR. I am actually curious which line of code triggered the sync?
Also cc: @awgu since we had a chat about gradient accumulation and FSDP in the early days of torchtune
For FSDP, there are two ways to accumulate gradients:
1. Accumulate unsharded gradients locally, skipping the reduce-scatter on accumulation microbatches (the model.no_sync context in FSDP1, model.set_requires_gradient_sync(is_last_microbatch) in FSDP2).
2. Reduce-scatter every microbatch and accumulate the sharded gradients (the default behavior).

We should differentiate between training throughput and memory usage (both of which could be referred to as "efficiency").
The two options to accumulate gradients pose a direct tradeoff between communication time and memory usage.
Note that extra communication time mainly translates to lower throughput when it cannot be fully overlapped, which depends on your inter-node bandwidth.
For Llama3-8B, the unsharded gradients take ~8B numel. Whether option 1 or option 2 makes sense depends on how much memory you can afford to trade off and whether you want to accumulate gradients in fp32 or bf16.
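As a rough sketch of option 1 with FSDP2 (assuming model already has fully_shard applied; microbatches, loss_fn, and opt are placeholder names):

num_microbatches = len(microbatches)
for i, (tokens, labels) in enumerate(microbatches):
    is_last_microbatch = i == num_microbatches - 1
    # only reduce-scatter gradients on the last microbatch; earlier microbatches
    # accumulate unsharded gradients locally (more memory, less communication)
    model.set_requires_gradient_sync(is_last_microbatch)
    loss = loss_fn(model(tokens), labels)
    loss.backward()
opt.step()
opt.zero_grad()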
Note that with FSDP2, you can play around with this at a finer granularity. For example, if you are "wrapping" / applying fully_shard to every transformer block, then you can selectively disable reduce-scatter for every other (or every k-th) transformer block, since set_requires_gradient_sync(bool) is a method on FSDPModule (which each transformer block becomes when you call fully_shard on it). This can help overlap the fewer remaining reduce-scatters.
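Concretely, something like this on the accumulation microbatches (a sketch assuming fully_shard was applied to each block in model.layers, so each block is an FSDPModule; model.layers is a placeholder name):

for i, block in enumerate(model.layers):
    # keep reduce-scatter enabled for only every other block (k=2) on accumulation
    # steps; the fewer remaining reduce-scatters are easier to overlap
    block.set_requires_gradient_sync(is_last_microbatch or i % 2 == 0)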
so in a no_sync context, gradients are accumulated in FP32?
I think in FSDP1, they are accumulated in the dtype that the gradients were computed in (e.g. bf16).
In FSDP2, if you specify MixedPrecisionPolicy(reduce_dtype=torch.float32), then it has extra logic to accumulate the gradients in fp32; if you did not specify that higher-precision reduce_dtype, then it will similarly accumulate in the dtype that the gradients were computed in.
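For reference, a sketch of that FSDP2 configuration (import paths may differ across PyTorch versions; older releases expose these under torch.distributed._composable.fsdp):

import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # parameters are cast to bf16 for compute/communication
    reduce_dtype=torch.float32,  # gradients are reduce-scattered and accumulated in fp32
)
for block in model.layers:  # model.layers is a placeholder for the transformer blocks
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)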
Extending on this (and maybe unrelated to the overall topic): in the current implementation of FSDP1, we are sharding parameters across nodes in a multi-node scenario (a ZeRO-3-style implementation). Is there a way to limit the sharding to be intra-node?
You should be able to pass device_mesh as a 2D device mesh to enable HSDP. (You could also pass in a 2-tuple of process groups, but I think the checkpointing support is better for the device mesh path.)
You can create the device mesh via something like:
import torch
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
global_world_size = torch.distributed.get_world_size() # global world size, assumes dist is initialized
intra_node_size = torch.cuda.device_count() # e.g. 8
device_mesh = init_device_mesh("cuda", (global_world_size // intra_node_size, intra_node_size)) # (replicate, shard)
fsdp_model = FSDP(model, ..., device_mesh=device_mesh)
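Note that for FSDP1 you would also pair the 2D mesh with a hybrid sharding strategy, roughly (a sketch; ShardingStrategy.HYBRID_SHARD applies FULL_SHARD within a node and replicates across nodes):

from torch.distributed.fsdp import ShardingStrategy

fsdp_model = FSDP(
    model,
    device_mesh=device_mesh,  # the 2D (replicate, shard) mesh from above
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard intra-node, replicate inter-node
)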
So, in the case of a small model (< 13B) that we want to scale to multiple nodes for higher throughput, is it better to use HSDP with a device_mesh, i.e. intra-node FSDP and inter-node DDP?
It depends on your inter-node bandwidth. If your inter-node bandwidth is fast, FSDP is probably still better, especially if your model is compute-dense like a transformer.
The overall workflow I would follow is to get a profiler trace, look at whether the communications are overlapping or not, and decide from there. If you can still overlap FSDP collectives in the multi-node setup, then you probably prefer FSDP, since you can save more memory and possibly reinvest that into batch size or into decreasing the amount of activation checkpointing.
If you have slow inter-node bandwidth, though, then it is possible that there is a really large discrete jump in communication time when going from single-node to multi-node, in which case HSDP can help because the only communication across nodes is the gradient all-reduce.
This perfectly solved my problem! Thank you!
@awgu Actually, one last question: in hybrid shard mode with multiple nodes, does "sync_module_states" still broadcast rank 0's params to ranks on different nodes?
@ScottHoang yes, it will broadcast from global rank 0 to all ranks (including both intra and inter-node process groups): https://github.com/pytorch/pytorch/blob/afaa5fcecb07472a8805902074f4611dc5798f76/torch/distributed/fsdp/_init_utils.py#L632-L635
For distributed recipes, such as full_finetune_distributed, the gradients end up getting synchronized after each backward() pass instead of only once before the optimizer step. This results in significant unnecessary communication overhead.
I noticed this when porting a training script from Hugging Face Accelerate to something based on the full_finetune_distributed recipe. For my use case (fine-tuning Llama3 on 8 nodes with 4xA100-80 per node), I noticed my training time more than doubled.
Digging in, I found this to be the cause. I've fixed this in my recipe, but it seems like the fix should be applied to all distributed recipes.
I'd be happy to contribute a PR if it's agreed that it should be fixed.
For reference, the fix is roughly wrapping the forward/backward with:
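(A minimal sketch, assuming an FSDP1-wrapped model; tokens, labels, and loss_fn are placeholder names.)

import contextlib

# skip FSDP's gradient reduce-scatter/all-reduce on accumulation steps and only
# synchronize on the microbatch right before optimizer.step()
grad_sync_ctx = contextlib.nullcontext() if do_sync else model.no_sync()
with grad_sync_ctx:
    logits = model(tokens)
    loss = loss_fn(logits, labels)
    loss.backward()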
where do_sync is True for the last step of the gradient accumulation.
(There's an additional bug where gradient accumulation runs after a single forward/backward for the first step, but that's fairly minor)