@weifengpy or @awgu should have the most intelligent response to the question. My understanding is that all modules in the root should be gathered together.
@weifengpy might also be able to comment on the NF4 issue.
@jeromeku regarding wrapped units, everything you said is correct! When you call `fully_shard(module)`, all parameters in `module.parameters()` become one group, except those already assigned to a nested `fully_shard` call. (Here is the code pointer if you are interested.)
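To make the grouping rule concrete, here is a rough sketch with a toy model (assuming a recent nightly where `fully_shard` is importable from `torch.distributed._composable.fsdp` and the script is launched with `torchrun` on GPUs; the model and module names are made up for illustration):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # import path may differ across versions

dist.init_process_group("nccl")  # assumes a torchrun launch, one GPU per rank
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(256, 256)
        self.mlp = nn.Linear(256, 256)

    def forward(self, x):
        return self.mlp(torch.relu(self.attn(x)))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1000, 256)
        self.layers = nn.ModuleList(Block() for _ in range(2))
        self.norm = nn.LayerNorm(256)
        self.output = nn.Linear(256, 1000, bias=False)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h)
        return self.output(self.norm(h))

model = Model().to("cuda")

# Each fully_shard call turns that module's parameters into one group,
# excluding parameters already claimed by a nested fully_shard call.
for block in model.layers:
    fully_shard(block)   # one group per block: attn + mlp
fully_shard(model)       # root group: tok_embeddings, norm, output (the leftovers)
```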
> Does this mean that these parameters will be all-gathered together? That is, when the token embeddings are needed, the RMSNorm and output projection params will be gathered as well, even though they are at the start and end of the forward pass respectively and hence not needed at that time?
One thing you could do is wrap the output projection separately so that it gets all-gathered at the end of forward. Note that the current FSDP1 and FSDP2 designs would not allow you to wrap `norm` and `output` together (to be all-gathered together), so at best, we can either wrap `norm` separately as well (incurring a small all-gather) or allow `norm` to be grouped with `tok_embeddings` at the root module.

However, for FSDP2, we are looking into (or actually, pretty much planning to) allow `FSDP(List[Module])`, so that you could call `fully_shard([model.norm, model.output])` to group `norm` and `output` together. For example: https://github.com/pytorch/torchtitan/pull/382 https://github.com/pytorch/pytorch/pull/127786
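Assuming a PyTorch version where that `List[Module]` overload is available, the grouping would look roughly like this on the toy model from the earlier sketch (a fresh, unwrapped instance):

```python
# Sketch only: requires a fully_shard that accepts a list of modules.
for block in model.layers:
    fully_shard(block)
fully_shard([model.norm, model.output])  # norm + output share one all-gather
fully_shard(model)                       # root group now holds only tok_embeddings
```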
@awgu
Thanks! A couple follow-up questions:
During the forward pass, when the `forward` of `embeddings` is called, the `norm` and output projections will also be gathered, and subsequently, when the `forward` of `norm` and the output projections is called, the `embeddings` will be gathered as well, since they are all wrapped in the same `FSDP` unit? These modules are at opposite ends of the computation graph, resulting in unnecessary parameter gathering.
The alternative, per your suggestion, is to either

1) individually wrap each of these modules, OR
2) wrap the `embeddings`, then call `fully_shard` on the entire model, which only has `norm` and the output projections not already wrapped.

Is there a reason not to do this, as it results in fewer unnecessary parameters being gathered? I ask because this is not my implementation but part of `torchtune`'s FSDP2 script, and I am trying to understand the logic.
Finally, is the general pattern to always call `fully_shard` on the entire model after specific modules are wrapped, to take care of any remaining modules that aren't already part of an FSDP unit? Even though it makes for a simpler API, this could result in "fragmented" units that contain modules that aren't contiguous in the computation graph, per the `embeddings` / output projection example above.

The dispatch error happens during saving the state dict. It should be resolved next week. I can keep the group updated.
Appreciate your timely feedback @jeromeku. As discussed above, wrapping `embeddings`, `norm`, and the output projection separately is more optimal. The production LoRA recipe has it: https://github.com/pytorch/torchtune/pull/865. FSDP2 + QLoRA should follow up.
@jeromeku These are good questions! Let me try to provide some more info.
Define `module` to be an FSDP module if we call `fully_shard(module)`.
In our current design, each FSDP module is 1:1 with 1 FSDP parameter group. All parameters in the FSDP module, excluding those in a nested FSDP module, comprise that group.
Note that if an FSDP module does not manage any parameters, then it would simply not have an FSDP parameter group. This happens when all parameters have been partitioned among nested FSDP modules, so the parent FSDP module has no parameters left to manage.
Before an FSDP module's forward (i.e. its pre-forward), we all-gather its group's parameters, and after its forward (i.e. its post-forward), we free its group's parameters. (We abstract these two ops out as `unshard` and `reshard`, respectively.) More specifically, we free a group's parameters after forward only if `reshard_after_forward=True` for that module. As a special case, we always set the root module's `reshard_after_forward` value to `False` since its parameters would get all-gathered immediately for backward anyway (typically).

Finally, by nature of our design, we always want to apply `fully_shard` to the root module. The reason is that the FSDP implementation uses several shared data structures for achieving computation and communication overlap (e.g. CUDA streams and some ordering info). Having the root module allows us to share data structures among all FSDP modules within its module tree.
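A rough sketch of how that plays out over one training step, assuming the toy wrapped `model` from the earlier sketch lives on the local GPU:

```python
import torch

tokens = torch.randint(0, 1000, (8, 16), device="cuda")

# Pre-forward of each FSDP module: unshard (all-gather) its parameter group.
# Post-forward: reshard (free) the group if reshard_after_forward=True.
# The root's group is kept unsharded through backward.
loss = model(tokens).sum()

# Backward unshards each non-root group again right before its gradients are
# computed, then reshards it and reduce-scatters its gradients.
loss.backward()
```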
Now, let us see if we have enough info to answer your questions.
> During the forward pass, when the forward of embeddings is called, the norm and output projections will also be gathered, and subsequently, when the forward of norm and output projections is called, the embeddings will be gathered as well since they are all wrapped in the same FSDP unit? These modules are at opposite ends of the computation graph, resulting in unnecessary parameter gathering.
If we wrap only each transformer block and then the root `Transformer`, then what happens is that upon the root's `forward`, the `tok_embeddings`, `norm`, and `output` parameters all get all-gathered in the same NCCL all-gather kernel. Since this is the root, which specially has `reshard_after_forward` clamped to `False`, the parameters would not get freed until the end of backward. This is indeed wasteful from a memory perspective.

Wrapping the `tok_embeddings`, `norm`, and `output` each separately would be more memory efficient (or similarly, the wrapping in https://github.com/pytorch/torchtitan/pull/382 that puts `norm` and `output` together, which requires not-yet-landed code), since then the `tok_embeddings` would only be all-gathered near the beginning of forward and the end of backward, and the `norm` / `output` would only be all-gathered near the end of forward and the beginning of backward (so you probably want to just use `reshard_after_forward=False` for `norm` and `output`).
> The alternative, per your suggestion is to either
> - individually wrap each of these modules OR
> - wrap the embeddings then call fully_shard on the entire model which only has norm and output projections not already wrapped.
>
> Is there a reason not to do this, as it results in fewer unnecessary parameters being gathered?
For the small # of GPUs case, there is probably not much reason to not just individually wrap these modules (so change the torchtune recipe's approach). However, just for context, doing that may not always be performant at larger # of GPUs; it depends on some other factors.
Communication cost can be modeled as (latency cost) + (bandwidth cost). For example, in forward, FSDP needs to all-gather the entire model in total, sharded across the `N` GPUs. Each all-gather has a latency cost (roughly `O(N)`), since the message is sent with `N-1` hops in a ring around `N` GPUs. Thus, we see two considerations:

- If you scale up the number of GPUs (`N`), then the communication latency cost increases and becomes more dominant. To reduce the latency cost, you must issue fewer collectives. You cannot issue too few or else you use too much memory or cannot overlap, but if you issue too many small collectives, you will pay the latency cost many more times.
- A small collective (e.g. all-gathering `norm` alone) achieves low bandwidth utilization. To reduce bandwidth cost, you prefer to issue larger collectives.

Thus, it really is a balancing game, and the precise best solution really depends on your setup. Wrapping each transformer block and the root is a good starting point. Additionally wrapping `tok_embeddings`, `norm`, and `output` can help and can be evaluated.
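A back-of-the-envelope sketch of that trade-off (the latency and bandwidth constants below are made up; only the shape of the trade-off matters):

```python
# Toy cost model: total all-gather bytes are fixed, so splitting the model into
# more groups mostly adds per-collective latency terms.
def forward_allgather_time(total_bytes, num_groups, num_gpus,
                           hop_latency=20e-6,      # seconds per ring hop (illustrative)
                           link_bandwidth=100e9):  # bytes/sec per link (illustrative)
    latency_cost = num_groups * (num_gpus - 1) * hop_latency
    bandwidth_cost = total_bytes * (num_gpus - 1) / num_gpus / link_bandwidth
    return latency_cost + bandwidth_cost

model_bytes = 7e9 * 2  # e.g. ~7B parameters in bf16
for groups in (1, 32, 300):
    print(groups, forward_allgather_time(model_bytes, groups, num_gpus=64))
```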
> Finally, is the general pattern to always call fully_shard on the entire model after specific modules are wrapped to take care of any remaining modules that aren't already part of an FSDP unit? Even though it makes for simpler API, this could result in "fragmented" units that contain modules that aren't contiguous in the computation graph, per the embeddings / output projection example above.
As mentioned above, calling `fully_shard` on the root is needed to share the data structures for comm./comp. overlap. I totally agree that it can create that fragmentation, so actually the optimal wrapping would be a 2-level hierarchy where the root is the top level and there is a flat sequence in the second level that totally partitions all parameters (so the root manages no parameters). In practice, this is not so easy with the existing constraint of `fully_shard(module: nn.Module)`, but if we land `fully_shard(List[nn.Module])`, then it should be much more doable.
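Under that relaxed API, the 2-level wrapping might look like this sketch (same toy model, fresh instance, assuming the list overload is available):

```python
# Flat second level that partitions every parameter, so the root owns none.
fully_shard([model.tok_embeddings])
for block in model.layers:
    fully_shard(block)
fully_shard([model.norm, model.output])
fully_shard(model)  # still required (shared streams / ordering), but manages no group
```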
After typing all of this up, I recognize this info may be useful to be accessible more broadly. @jeromeku @kartikayk I am happy to see where we may be able to host this info in a more accessible place, whether that is in the PyTorch docs (note `fully_shard` does not have public docs rendered yet), torchtitan, torchtune, etc.
@awgu
Many thanks for the wonderful, well-written response -- clears things up!
As for documentation, that would be helpful. I'd been digging through `torch.distributed._composable` and, though the classes/functions have docstrings, I couldn't find a rendered version on the official PyTorch documentation site.

Perhaps `torchtitan` would be an easy (temporary) landing place? There is already an FSDP note under docs. Longer term, I'd imagine it'd make sense to include it under the Developer Notes section of the official docs.
Landed the fix in trunk: https://github.com/pytorch/torchtune/pull/1077. Now we support saving the state dict per epoch. We can save the model for inference or resume training from checkpoints.

@jeromeku let me know if the NF4 dispatch error is gone after rebasing to the latest trunk.
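For anyone following along, one way to pull a full state dict out of a `fully_shard`-wrapped model for saving is via the `torch.distributed.checkpoint.state_dict` helpers; this is a generic sketch, not necessarily the exact code path the torchtune recipe uses, and the output path is hypothetical:

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

# Gather a full, CPU-resident state dict from the sharded model (all ranks call this).
state_dict = get_model_state_dict(
    model,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
if dist.get_rank() == 0:
    torch.save(state_dict, "model_checkpoint.pt")  # hypothetical output path
```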
@weifengpy
Thanks! Works now.
FSDP2 + QLoRA: NF4 Dispatch Error

**NF4 Dispatch Error**

Getting the following error when trying to run FSDP2 + QLoRA distributed fine-tuning:

Ran with the following command:

where `7B_qlora_fsdp2.yaml` is a local copy of `recipes/dev/7B_qlora_fsdp2.yaml`. `torch` and `torchao` are both nightly installs:

**`fully_shard` wrapped units**

When the model is wrapped with `fully_shard` here, the `lora` layers and `TransformerDecoderLayer` are individually wrapped, then the entire model (`TransformerDecoder`) is finally wrapped.

Stepping through with a debugger, we can see that the top-level `FSDP` unit manages whatever has not been wrapped: the `tok_embeddings`, `RMSNorm`, and output projection (same shape as the `embeddings` but not tied).

Does this mean that these parameters will be all-gathered together? That is, when the token embeddings are needed, the `RMSNorm` and output projection params will be gathered as well, even though they are at the start and end of the forward pass respectively and hence not needed at that time?