huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

RuntimeError when combining FSDP and disable_adapter #1442

Closed wanghao14 closed 6 months ago

wanghao14 commented 8 months ago

System Info

peft: 0.7.1; torch: 2.3.0.dev20240128+cu121; accelerate: 0.26.1; transformers: 4.37.2; Python: 3.10.12. Using the PyTorch container 23.12 provided by NVIDIA.


The hardware environment consists of four A100-40G GPUs.

Who can help?

@pacman100 @younesbelkada

Information

Tasks

Reproduction

Hi, I want to use both FSDP and PEFT in my project. I insert LoRA into the pretrained LLM with peft.get_peft_model and then wrap the whole model with torch.distributed.fsdp.FullyShardedDataParallel. The only trainable part of the model is the LoRA adapter. Additionally, I need to call the original model via with my_model.disable_adapter():. When running the code, I encounter the following error (relevant parts excerpted):

File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward output = self._fsdp_wrapped_module(*args, kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, *kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(args, kwargs) File "/data4/Projects/CoCap/mm_video/experiments/context_compression/modeling/in_context_autoencoder.py", line 373, in forward with self.base_llm.disable_adapter(): File "/usr/lib/python3.10/contextlib.py", line 135, in enter return next(self.gen) File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 567, in disable_adapter self.base_model.disable_adapter_layers() File "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py", line 403, in disable_adapter_layers self._set_adapter_layers(enabled=False) File "/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py", line 381, in _set_adapter_layers module.enable_adapters(enabled) File "/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py", line 403, in enable_adapters layer.requiresgrad(False) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2435, in requiresgrad p.requiresgrad(requires_grad) RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

Expected behavior

with my_model.disable_adapter(): should let me call the original model, even though it is wrapped by FSDP.

BenjaminBossan commented 8 months ago

Hmm, that's unfortunate. If it's possible memory-wise, could you please test if the same happens without FSDP? Maybe you have a smaller model available that you could test it with.

As a potential workaround, could you please check whether this works:

from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

# to disable adapters
for module in model.modules():
    if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
        module._disable_adapters = True

output = model(...)

# to re-enable them, assuming there is only one adapter
for module in model.modules():
    if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
        module._disable_adapters = False
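
If that works, you could also wrap the toggling in a small context manager so the adapters are re-enabled automatically (just a sketch on top of the snippet above, not part of the PEFT API, and it assumes the adapters were enabled to begin with):

from contextlib import contextmanager

from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

@contextmanager
def adapters_disabled(model):
    # flip the private flag directly instead of calling enable_adapters(),
    # which also changes requires_grad and fails on FSDP's flattened params
    adapter_modules = [
        m for m in model.modules()
        if isinstance(m, (BaseTunerLayer, ModulesToSaveWrapper))
    ]
    for m in adapter_modules:
        m._disable_adapters = True
    try:
        yield
    finally:
        # re-enable, assuming there is only one adapter
        for m in adapter_modules:
            m._disable_adapters = False

# usage
with adapters_disabled(model):
    output = model(...)
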
wanghao14 commented 8 months ago

Apologies for the delayed response.

  1. After conducting tests, I found that combining DDP with my_model.disable_adapter() works successfully, indicating that the issue stems from the introduction of FSDP.
  2. I've adapted your suggested code into my project and can confirm that the model can be trained. However, it's worth noting that, compared to DDP, performance appears significantly lower at the beginning of training and GPU memory occupancy is higher.
github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.