pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

FSDP2 + QLoRA: `NF4` dispatch error #1072

Closed · jeromeku closed this issue 4 months ago

jeromeku commented 5 months ago

FSDP2 + QLoRA: NF4 Dispatch Error

Getting the following error when trying to run FSDP2 + QLoRA distributed fine-tuning:

```
[rank0]: NotImplementedError: NF4Tensor dispatch: attempting to run _c10d_functional.all_gather_into_tensor.default, this is not supported
```

Ran with the following command:

```
tune run --nnodes=1 --nproc_per_node=2 lora_finetune_fsdp2 \
    --config 7B_qlora_fsdp2.yaml
```

where `7B_qlora_fsdp2.yaml` is a local copy of `recipes/dev/7B_qlora_fsdp2.yaml`.

torch and torchao are both nightly installs:

**`fully_shard`-wrapped units**

When the model is wrapped with `fully_shard` here, the LoRA layers and each `TransformerDecoderLayer` are individually wrapped, and then the entire model (`TransformerDecoder`) is wrapped last.
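Roughly, that wrapping order looks like the following sketch (not the recipe's actual code; the import paths are my guess at torchtune's layout):

```
from torch.distributed._composable.fsdp import fully_shard
from torchtune.modules import TransformerDecoderLayer  # assumed import path
from torchtune.modules.peft import LoRALinear           # assumed import path

# model: the TransformerDecoder built by the recipe.
# Wrap children before parents: reversed pre-order visits descendants first.
for m in reversed(list(model.modules())):
    if isinstance(m, (LoRALinear, TransformerDecoderLayer)):
        fully_shard(m)  # each LoRA layer / decoder layer becomes its own FSDP unit
fully_shard(model)      # finally, wrap the whole TransformerDecoder (the root unit)
```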

Stepping through with a debugger, we can see that the top-level FSDP unit manages whatever has not been wrapped: the tok_embeddings, RMSNorm and output projection (same shape as the embeddings but not tied).

Does this mean that these parameters will be all-gathered together? That is, when the token embeddings are needed, the RMSNorm and output projection params will be gathered as well, even though they are at the start and end of the forward pass respectively and hence not needed at that time?

kartikayk commented 5 months ago

@weifengpy or @awgu should have the most intelligent response to the question. My understanding is that all modules in the root should be gathered together.

@weifengpy might also be able to comment on the NF4 issue

awgu commented 5 months ago

@jeromeku regarding wrapped units, everything you said is correct! When you call `fully_shard(module)`, all parameters in `module.parameters()` become one group, except those already assigned to a nested `fully_shard` call. (Here is the code pointer if you are interested.)

> Does this mean that these parameters will be all-gathered together? That is, when the token embeddings are needed, the RMSNorm and output projection params will be gathered as well, even though they are at the start and end of the forward pass respectively and hence not needed at that time?

One thing you could do is wrap the output projection separately so that it gets all-gathered at the end of forward. Note that the current FSDP1 and FSDP2 designs would not allow you to wrap norm and output together (to be all-gathered together), so at best we can either wrap norm separately as well (incurring a small all-gather) or allow norm to be grouped with the tok_embeddings at the root module.

However, for FSDP2, we are looking into (or actually, pretty much planning) allowing `fully_shard(List[Module])`, so that you could call `fully_shard([model.norm, model.output])` to group the norm and output together.

For example: https://github.com/pytorch/torchtitan/pull/382 https://github.com/pytorch/pytorch/pull/127786

jeromeku commented 5 months ago

@awgu

Thanks! A couple of follow-up questions:

Is there a reason not to do this, since it results in fewer unnecessary parameters being gathered? I ask because this is not my implementation but part of torchtune's FSDP2 script, and I am trying to understand the logic.

weifengpy commented 5 months ago

The dispatch error happens during state dict saving. It should be resolved next week; I can keep the group updated.

Appreciate your timely feedback @jeromeku. As discussed above, wrapping the embeddings, norm, and output projection separately is more optimal; the production LoRA recipe already does this: https://github.com/pytorch/torchtune/pull/865. FSDP2 + QLoRA should follow up.

awgu commented 5 months ago

@jeromeku These are good questions! Let me try to provide some more info.

Define `module` to be an _FSDP module_ if we call `fully_shard(module)`.

**Aside: FSDP Module vs. FSDP Unit**

In the past, in our FSDP [paper](https://arxiv.org/abs/2304.11277), we called this an _FSDP unit_. It is the same thing, just that FSDP1 was an `nn.Module` wrapper, where a `FullyShardedDataParallel` module wrapped the original module (e.g. `TransformerBlock`), so technically an FSDP unit could be the `FullyShardedDataParallel` module itself.
**Examples**

**Wrapping 1**

```
model = Transformer(...)
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
```

The FSDP modules are the root `Transformer` module and each `TransformerBlock` module. Other modules like `Attention`, `Feedforward`, `RMSNorm`, `Linear`, etc. are _not_ FSDP modules.

**Wrapping 2**

```
model = Transformer(...)
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model.tok_embeddings)  # add this
fully_shard(model.output)  # add this
fully_shard(model)
```

Now, we have two additional FSDP modules: `model.tok_embeddings` and `model.output`.

In our current design, each FSDP module is 1:1 with 1 FSDP parameter group. All parameters in the FSDP module, excluding those in a nested FSDP module, comprise that group.

**Examples**

For Wrapping 1 above, the FSDP parameter groups are as follows:

- Each transformer block forms one parameter group.
- The root `Transformer` forms one parameter group. It has all of the parameters that are not part of a transformer block (e.g. the `tok_embeddings`, `norm`, and `output` parameters).

For Wrapping 2, we have two additional FSDP parameter groups, moving parameters from the root module's group into these two new groups:

- Each transformer block forms one parameter group.
- The `tok_embeddings` forms one parameter group, consisting of the embedding weight.
- The `output` forms one parameter group, consisting of the linear weight.
- The root `Transformer` forms one parameter group, consisting of just the `norm` weight.

Note that if an FSDP module does not manage any parameters, then it would simply not have an FSDP parameter group. This happens when all parameters have been partitioned among nested FSDP modules, so the parent FSDP module has no parameters left to manage.

**Examples**

```
model = Transformer(...)
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model.tok_embeddings)
fully_shard(model.output)
fully_shard(model.norm)
fully_shard(model)  # `model` does not have an FSDP parameter group!
```

(This assumes that there are no other parameters, e.g. no positional embedding weight. Hopefully the idea is still clear.)

Before an FSDP module's forward (i.e. its pre-forward), we all-gather its group's parameters, and after its forward (i.e. its post-forward), we free its group's parameters. (We abstract these two ops out as unshard and reshard, respectively.) More specifically, we free a group's parameters after forward only if reshard_after_forward=True for that module. As a special case, we always set the root module's reshard_after_forward value to False since its parameters would get all-gathered immediately for backward anyway (typically).
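For illustration, a small sketch of how this argument is passed (assuming a model with a `layers` container; `True` is the default):

```
for block in model.layers:
    # Free this block's unsharded parameters after its forward; they get
    # all-gathered again during backward.
    fully_shard(block, reshard_after_forward=True)

# Root: its reshard_after_forward is effectively treated as False, so whatever
# parameters it manages stay unsharded from forward through the end of backward.
fully_shard(model)
```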

Finally, by nature of our design, we always want to apply fully_shard to the root module. The reason is that the FSDP implementation uses several shared data structures for achieving computation and communication overlap (e.g. CUDA streams and some ordering info). Having the root module allows us to share data structures among all FSDP modules within its module tree.


Now, let us see if we have enough info to answer your questions.

> During the forward pass, when the forward of embeddings is called, the norm and output projections will also be gathered, and subsequently, when the forward of norm and output projections is called, the embeddings will be gathered as well since they are all wrapped in the same FSDP unit? These modules are at opposite ends of the computation graph, resulting in unnecessary parameter gathering.

If we wrap only each transformer block and then the root Transformer, then what happens is that upon the root's forward, the tok_embeddings, norm, and output parameters all get all-gathered in the same NCCL all-gather kernel. Since this is the root which specially has reshard_after_forward clamped to False, the parameters would not get freed until the end of backward. This is indeed wasteful from a memory perspective.

Wrapping the tok_embeddings, norm, and output each separately would be more memory efficient (or, similarly, the wrapping in https://github.com/pytorch/torchtitan/pull/382 that puts norm and output together, which requires not-yet-landed code). Then the tok_embeddings would only be all-gathered near the beginning of forward and the end of backward, and the norm / output would only be all-gathered near the end of forward and the beginning of backward (so you probably want to just use reshard_after_forward=False for norm and output).
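For concreteness, a sketch of that wrapping (assuming the usual `tok_embeddings` / `layers` / `norm` / `output` attribute names):

```
fully_shard(model.tok_embeddings)   # needed near the start of forward and the end of backward
for block in model.layers:
    fully_shard(block)
# norm and output are needed at the end of forward and again right at the start
# of backward, so skip resharding them in between.
fully_shard(model.norm, reshard_after_forward=False)
fully_shard(model.output, reshard_after_forward=False)
fully_shard(model)                  # root; may manage no parameters but is still required
```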

> The alternative, per your suggestion, is to either
>
> - individually wrap each of these modules, OR
> - wrap the embeddings then call fully_shard on the entire model, which only has norm and output projections not already wrapped.
>
> Is there a reason not to do this, as it results in fewer unnecessary parameters being gathered?

For the small # of GPUs case, there is probably not much reason not to just individually wrap these modules (so, change the torchtune recipe's approach). However, just for context, doing that may not always be performant at a larger # of GPUs; it depends on some other factors.

Communication cost can be roughly modeled as (latency cost) + (bandwidth cost). For example, in forward, FSDP needs to all-gather the entire model in total, so the total bytes communicated (the bandwidth cost) is roughly fixed no matter how the parameters are grouped.

Thus, we see two considerations:

- More, smaller parameter groups mean more all-gather calls, each paying a fixed latency cost, but they free memory sooner and allow finer-grained overlap.
- Fewer, larger groups amortize the latency cost over fewer collectives, but they keep more unsharded parameters in memory at once.

Thus, it really is a balancing game, and the precise best solution depends on your setup. Wrapping each transformer block and the root is a good starting point. Additionally wrapping tok_embeddings, norm, and output can help and is worth evaluating.
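As a rough back-of-envelope illustration of this trade-off (the helper and all numbers below are hypothetical placeholders, not measurements):

```
def fwd_allgather_time_s(num_groups, total_bytes,
                         per_collective_latency_s=20e-6,  # assumed fixed cost per all-gather
                         bandwidth_bytes_per_s=400e9):    # assumed effective interconnect bandwidth
    # Bandwidth term is fixed by model size; latency term grows with the number
    # of FSDP parameter groups (i.e. the number of all-gathers issued in forward).
    return num_groups * per_collective_latency_s + total_bytes / bandwidth_bytes_per_s

coarse = fwd_allgather_time_s(num_groups=33, total_bytes=14e9)  # e.g. 32 blocks + root
fine = fwd_allgather_time_s(num_groups=36, total_bytes=14e9)    # + tok_embeddings, norm, output
```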

> Finally, is the general pattern to always call fully_shard on the entire model after specific modules are wrapped, to take care of any remaining modules that aren't already part of an FSDP unit? Even though it makes for a simpler API, this could result in "fragmented" units that contain modules that aren't contiguous in the computation graph, per the embeddings / output projection example above.

As mentioned above, calling fully_shard on the root is needed to share the data structures for comm./comp. overlap. I totally agree that it can create that fragmentation, so actually the optimal wrapping would be a 2-level hierarchy, where the root is the top level and the second level is a flat sequence that totally partitions all parameters (so the root manages no parameters). In practice, this is not so easy with the existing constraint of fully_shard(module: nn.Module), but if we land fully_shard(List[nn.Module]), then it should be much more doable.
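Under the proposed `fully_shard(List[nn.Module])` overload (not yet landed at the time of writing), that 2-level wrapping could look roughly like this sketch:

```
# Second level: a flat sequence of groups that together cover all parameters.
fully_shard(model.tok_embeddings)
for block in model.layers:
    fully_shard(block)
fully_shard([model.norm, model.output])  # norm + output share one all-gather

# Top level: the root shares the comm./comp. data structures and, since every
# parameter is already assigned to a nested group, manages no parameters itself.
fully_shard(model)
```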


After typing all of this up, I recognize this info may be useful to make accessible more broadly. @jeromeku @kartikayk I am happy to see where we might be able to host this info in a more accessible place, whether that is in the PyTorch docs (note that fully_shard does not have public docs rendered yet), torchtitan, torchtune, etc.

jeromeku commented 5 months ago

@awgu

Many thanks for the wonderful, well-written response -- clears things up!

As for documentation, that would be helpful. I'd been digging through torch.distributed._composable, and though the classes / functions have docstrings, I couldn't find a rendered version on the official PyTorch documentation site.

Perhaps torchtitan would be an easy (temporary) landing place? There is already an FSDP note under docs. Longer term, I'd imagine it'd make sense to include it under the Developer Notes section of the official docs.

weifengpy commented 4 months ago

Landed the fix in trunk: https://github.com/pytorch/torchtune/pull/1077. We now support saving the state dict per epoch, so you can save the model for inference or resume training from checkpoints.

@jeromeku let me know if the NF4 dispatch error is gone after rebasing to the latest trunk.

jeromeku commented 4 months ago

@weifengpy

Thanks! Works now.