huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
Apache License 2.0

Accelerate 0.30.0 Breaks FSDP QLora #2761

Closed mallorbc closed 1 day ago

mallorbc commented 1 month ago

System Info

Below is the pip list output for an environment that does not work:

Package                  Version
------------------------ ---------------
accelerate               0.30.0
aiohttp                  3.9.5
aiosignal                1.3.1
annotated-types          0.6.0
async-timeout            4.0.3
attrs                    23.2.0
bitsandbytes             0.43.1
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
datasets                 2.19.1
deepspeed                0.14.2+5f631abc
dill                     0.3.8
docker-pycreds           0.4.0
docstring_parser         0.16
einops                   0.8.0
eval_type_backport       0.2.0
exceptiongroup           1.2.1
filelock                 3.14.0
flash-attn               2.5.8
frozenlist               1.4.1
fsspec                   2024.3.1
gitdb                    4.0.11
GitPython                3.1.43
hf_transfer              0.1.6
hjson                    3.1.0
huggingface-hub          0.23.0
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.1
numpy                    1.24.4
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
packaging                24.0
pandas                   2.0.3
peft                     0.10.0
pillow                   10.3.0
pip                      24.0
platformdirs             4.2.1
pluggy                   1.5.0
protobuf                 3.20.1
psutil                   5.9.8
py-cpuinfo               9.0.0
pyarrow                  16.0.0
pyarrow-hotfix           0.6
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pynvml                   11.5.0
pytest                   8.2.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.10
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scipy                    1.10.1
sentencepiece            0.2.0
sentry-sdk               2.1.1
setproctitle             1.3.3
setuptools               69.5.1
shtab                    1.7.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12
text-generation          0.7.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.3.0
torchaudio               2.3.0
torchvision              0.18.0
tqdm                     4.66.4
transformers             4.40.2
triton                   2.3.0
trl                      0.8.6
typing_extensions        4.11.0
tyro                     0.8.4
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.17.0
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4

After changing accelerate to accelerate<=0.29.3, the following environment works:
Package                  Version
------------------------ ---------------
accelerate               0.29.3
aiohttp                  3.9.5
aiosignal                1.3.1
annotated-types          0.6.0
async-timeout            4.0.3
attrs                    23.2.0
bitsandbytes             0.43.1
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
datasets                 2.19.1
deepspeed                0.14.2+5f631abc
dill                     0.3.8
docker-pycreds           0.4.0
docstring_parser         0.16
einops                   0.8.0
eval_type_backport       0.2.0
exceptiongroup           1.2.1
filelock                 3.14.0
flash-attn               2.5.8
frozenlist               1.4.1
fsspec                   2024.3.1
gitdb                    4.0.11
GitPython                3.1.43
hf_transfer              0.1.6
hjson                    3.1.0
huggingface-hub          0.23.0
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.1
numpy                    1.24.4
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
packaging                24.0
pandas                   2.0.3
peft                     0.10.0
pillow                   10.3.0
pip                      24.0
platformdirs             4.2.1
pluggy                   1.5.0
protobuf                 3.20.1
psutil                   5.9.8
py-cpuinfo               9.0.0
pyarrow                  16.0.0
pyarrow-hotfix           0.6
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pynvml                   11.5.0
pytest                   8.2.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.10
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scipy                    1.10.1
sentencepiece            0.2.0
sentry-sdk               2.1.1
setproctitle             1.3.3
setuptools               69.5.1
shtab                    1.7.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12
text-generation          0.7.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.3.0
torchaudio               2.3.0
torchvision              0.18.0
tqdm                     4.66.4
transformers             4.40.2
triton                   2.3.0
trl                      0.8.6
typing_extensions        4.11.0
tyro                     0.8.4
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.17.0
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4




I am using code based on the code here:

Otherwise, the basic steps are the following:

  1. Install the pip packages seen above, namely:
       pip install "accelerate<=0.29.3"
       pip install transformers accelerate peft bitsandbytes trl
  2. Use a QLora FSDP program
  3. Notice how errors occur with 0.30.0 but not 0.29.3
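When reproducing the steps above, it can help to record which versions are actually installed before launching training. The following is a minimal sketch using only the standard library; the helper names `parse_release`, `report`, and `accelerate_in_reported_range` are made up for illustration, and the `0.30.0` threshold is simply the first failing version reported in this issue, not an official compatibility boundary.

```python
import re
from importlib import metadata


def parse_release(version):
    """Return the leading numeric release, e.g. '0.11.1.dev0' -> (0, 11, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def report(packages=("accelerate", "peft", "transformers", "trl", "bitsandbytes")):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def accelerate_in_reported_range(version, first_bad="0.30.0"):
    """True if `version` is at or above the first version reported as failing here."""
    return parse_release(version) >= parse_release(first_bad)


if __name__ == "__main__":
    versions = report()
    for name, version in versions.items():
        print(f"{name:<14} {version}")
    if versions["accelerate"] is not None:
        print("at or above 0.30.0:", accelerate_in_reported_range(versions["accelerate"]))
```

Whether a given accelerate version actually fails also depends on the installed PEFT version, so the check is only a hint about which side of the reported boundary an environment is on.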

See an error like the following for 0.30.0:

[rank0]: Traceback (most recent call last):
[rank0]:   File "", line 387, in <module>
[rank0]:     trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/trl/trainer/", line 361, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 1859, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 2001, in _inner_training_loop
[rank0]:     self._fsdp_qlora_plugin_updates()
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 4425, in _fsdp_qlora_plugin_updates
[rank0]:     fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/peft/utils/", line 396, in fsdp_auto_wrap_policy
[rank0]:     transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
[rank0]: AttributeError: type object 'FullyShardedDataParallelPlugin' has no attribute 'get_module_class_from_name'
[rank1]: Traceback (most recent call last):
[rank1]:   File "", line 387, in <module>
[rank1]:     trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
[rank1]:   File "/usr/local/lib/python3.8/dist-packages/trl/trainer/", line 361, in train
[rank1]:     output = super().train(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 1859, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 2001, in _inner_training_loop
[rank1]:     self._fsdp_qlora_plugin_updates()
[rank1]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 4425, in _fsdp_qlora_plugin_updates
[rank1]:     fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
[rank1]:   File "/usr/local/lib/python3.8/dist-packages/peft/utils/", line 396, in fsdp_auto_wrap_policy
[rank1]:     transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
[rank1]: AttributeError: type object 'FullyShardedDataParallelPlugin' has no attribute 'get_module_class_from_name'
E0510 12:16:25.853937 140644343273280 torch/distributed/elastic/multiprocessing/] failed (exitcode: 1) local_rank: 0 (pid: 140) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/", line 46, in main
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/", line 1069, in launch_command
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/", line 718, in multi_gpu_launcher
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/", line 870, in run
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/", line 263, in launch_agent
    raise ChildFailedError(
============================================================ FAILED
  time      : 2024-05-10_12:16:25
  host      : f61090d2a6fd
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 141)
  error_file: <N/A>
  traceback : To enable traceback see:
Root Cause (first observed failure):
  time      : 2024-05-10_12:16:25
  host      : f61090d2a6fd
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 140)
  error_file: <N/A>
  traceback : To enable traceback see:
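The AttributeError above says that `FullyShardedDataParallelPlugin` no longer exposes `get_module_class_from_name` as a class attribute, while PEFT 0.10.0 still looks it up there. A common way for a caller to tolerate this kind of move is to try the old location first and fall back to a new one. The sketch below shows that pattern with stand-in objects only; `resolve_get_module_class_from_name`, `OldPlugin`, `NewPlugin`, and `fake_module` are hypothetical, and the idea that the helper is now a module-level function in `accelerate.utils.dataclasses` is an assumption that should be checked against the installed accelerate.

```python
from types import SimpleNamespace

HELPER_NAME = "get_module_class_from_name"


def resolve_get_module_class_from_name(plugin_cls, fallback_module):
    """Return the helper from the plugin class if present, else from the module."""
    if hasattr(plugin_cls, HELPER_NAME):
        return getattr(plugin_cls, HELPER_NAME)
    if hasattr(fallback_module, HELPER_NAME):
        return getattr(fallback_module, HELPER_NAME)
    raise AttributeError(f"{HELPER_NAME} not found on the plugin class or the module")


class OldPlugin:
    """Stand-in for a plugin class that still carries the helper."""

    @staticmethod
    def get_module_class_from_name(model, name):
        return ("plugin", name)


class NewPlugin:
    """Stand-in for a plugin class where the helper was removed."""


# Stand-in for a module that provides the helper as a plain function.
fake_module = SimpleNamespace(
    get_module_class_from_name=lambda model, name: ("module", name)
)

old_helper = resolve_get_module_class_from_name(OldPlugin, fake_module)
new_helper = resolve_get_module_class_from_name(NewPlugin, fake_module)
print(old_helper(None, "LlamaDecoderLayer"))
print(new_helper(None, "LlamaDecoderLayer"))
```

With the real libraries, the two arguments would presumably be `FullyShardedDataParallelPlugin` and the `accelerate.utils.dataclasses` module, but upgrading PEFT is the simpler route if the installed version already handles this.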

Expected behavior

I expect training to run without issues. This is what happens when I use accelerate 0.29.3.

muellerzr commented 1 month ago

cc @younesbelkada @pacman100

BenjaminBossan commented 1 month ago

@mallorbc Could you try installing PEFT from main and check if the error persists?

mallorbc commented 1 month ago

So use latest accelerate and install peft from main?

I will do the following:

pip install transformers bitsandbytes trl accelerate
pip install git+

I will let you know

mallorbc commented 1 month ago

I did the above setup. Here is my pip list:

Package                  Version
------------------------ ---------------
accelerate               0.30.1
aiohttp                  3.9.5
aiosignal                1.3.1
annotated-types          0.6.0
async-timeout            4.0.3
attrs                    23.2.0
bitsandbytes             0.43.1
certifi                  2024.2.2
charset-normalizer       3.3.2
click                    8.1.7
datasets                 2.19.1
deepspeed                0.14.2+5f631abc
dill                     0.3.8
docker-pycreds           0.4.0
docstring_parser         0.16
einops                   0.8.0
eval_type_backport       0.2.0
exceptiongroup           1.2.1
filelock                 3.14.0
flash-attn               2.5.8
frozenlist               1.4.1
fsspec                   2024.3.1
gitdb                    4.0.11
GitPython                3.1.43
hf_transfer              0.1.6
hjson                    3.1.0
huggingface-hub          0.23.0
idna                     3.7
iniconfig                2.0.0
Jinja2                   3.1.4
markdown-it-py           3.0.0
MarkupSafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.1
ninja
numpy                    1.24.4
nvidia-cublas-cu12
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12
nvidia-cufft-cu12
nvidia-curand-cu12
nvidia-cusolver-cu12
nvidia-cusparse-cu12
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.1.105
packaging                24.0
pandas                   2.0.3
peft                     0.11.1.dev0
pillow                   10.3.0
pip                      24.0
platformdirs             4.2.2
pluggy                   1.5.0
protobuf                 3.20.1
psutil                   5.9.8
py-cpuinfo               9.0.0
pyarrow                  16.1.0
pyarrow-hotfix           0.6
pydantic                 2.7.1
pydantic_core            2.18.2
Pygments                 2.18.0
pynvml                   11.5.0
pytest                   8.2.0
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.5.15
requests                 2.31.0
rich                     13.7.1
safetensors              0.4.3
scipy                    1.10.1
sentencepiece            0.2.0
sentry-sdk               2.2.0
setproctitle             1.3.3
setuptools               69.5.1
shtab                    1.7.1
six                      1.16.0
smmap                    5.0.1
sympy                    1.12
text-generation          0.7.0
tokenizers               0.19.1
tomli                    2.0.1
torch                    2.3.0
torchaudio               2.3.0
torchvision              0.18.0
tqdm                     4.66.4
transformers             4.40.2
triton                   2.3.0
trl                      0.8.6
typing_extensions        4.11.0
tyro                     0.8.4
tzdata                   2024.1
urllib3                  2.2.1
wandb                    0.17.0
wheel                    0.43.0
xxhash                   3.4.1
yarl                     1.9.4

I can confirm that this led to successful fine-tuning with QLora with FSDP. However, QDora seems to be broken.

When I try doing FSDP QDora, I get the following issue:

[rank0]: Traceback (most recent call last):
[rank0]:   File "", line 399, in <module>
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/trl/trainer/", line 361, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 1859, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/transformers/", line 2002, in _inner_training_loop
[rank0]:     self.model = self.accelerator.prepare(self.model)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/accelerate/", line 1292, in prepare
[rank0]:     result = tuple(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/accelerate/", line 1293, in <genexpr>
[rank0]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/accelerate/", line 1169, in _prepare_one
[rank0]:     return self.prepare_model(obj, device_placement=device_placement)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/accelerate/", line 1459, in prepare_model
[rank0]:     model = FSDP(model, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 485, in __init__
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 101, in _auto_wrap
[rank0]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 561, in _recursive_wrap
[rank0]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 490, in _wrap
[rank0]:     return wrapper_cls(module, **kwargs)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 511, in __init__
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 598, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 610, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 582, in __init__
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 632, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:   File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/", line 768, in _validate_tensors_to_flatten
[rank0]:     raise ValueError("Cannot flatten integer dtype tensors")
[rank0]: ValueError: Cannot flatten integer dtype tensors
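The final error comes from FSDP refusing to flatten a parameter whose dtype is not floating point. With 4-bit bitsandbytes quantization, the packed weights are stored in an integer dtype (uint8) unless configured otherwise, which is the likely trigger here. One way to see which parameters would be rejected is to list every parameter whose dtype is not floating point before wrapping the model. The sketch below is duck-typed so it runs without torch; `find_non_float_params`, `FakeDtype`, and `FakeParam` are hypothetical names, and the demo parameter names are invented for illustration.

```python
def find_non_float_params(named_params):
    """Return (name, dtype) pairs for parameters whose dtype is not floating point.

    `named_params` is any iterable of (name, param) pairs where `param.dtype`
    has an `is_floating_point` attribute, as torch dtypes do.
    """
    offenders = []
    for name, param in named_params:
        if not param.dtype.is_floating_point:
            offenders.append((name, str(param.dtype)))
    return offenders


class FakeDtype:
    """Stand-in for a torch dtype, used so the demo runs without torch."""

    def __init__(self, name, is_floating_point):
        self.name = name
        self.is_floating_point = is_floating_point

    def __str__(self):
        return self.name


class FakeParam:
    """Stand-in for a parameter that only carries a dtype."""

    def __init__(self, dtype):
        self.dtype = dtype


demo_params = [
    ("layers.0.q_proj.lora_A.weight", FakeParam(FakeDtype("bfloat16", True))),
    ("layers.0.q_proj.base_layer.weight", FakeParam(FakeDtype("uint8", False))),
]
offenders = find_non_float_params(demo_params)
print(offenders)
```

On a real model, the same function could be called as `find_non_float_params(model.named_parameters())` before `accelerator.prepare`; any entry it returns is a candidate for the "Cannot flatten integer dtype tensors" error.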

jaywongs commented 1 month ago

"... with QLora with FSDP"

I used exactly the versions you mentioned, and with FSDP + QLoRA I got the same "ValueError: Cannot flatten integer dtype tensors".

BenjaminBossan commented 1 month ago

For QLoRA training with FSDP, please check the updated bitsandbytes docs.
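For readers who cannot follow the link, the relevant setting is most likely the quantization storage dtype. As far as I understand the bitsandbytes FSDP-QLoRA guidance, the 4-bit weights need to be stored in a floating-point dtype that matches the rest of the model so that FSDP can flatten them. A minimal sketch, assuming a transformers version that supports `bnb_4bit_quant_storage` and a bfloat16 setup; the function name `build_fsdp_qlora_quant_config` is hypothetical and the specific values are illustrative, not taken from this thread.

```python
def build_fsdp_qlora_quant_config():
    """Build a 4-bit quantization config whose storage dtype is floating point.

    Imports are kept inside the function so the definition itself does not
    require torch or transformers to be installed.
    """
    import torch
    from transformers import BitsAndBytesConfig

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        # Assumption: storing the packed 4-bit weights as bfloat16 instead of
        # the default integer dtype is what lets FSDP flatten them.
        bnb_4bit_quant_storage=torch.bfloat16,
    )
```

The resulting object would be passed as `quantization_config=` to `AutoModelForCausalLM.from_pretrained`, typically together with `torch_dtype=torch.bfloat16` so that the non-quantized parameters use the same dtype.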

As for QDoRA: training with FSDP should be fixed in a recent PEFT PR. If you install from the latest PEFT main, it should thus work. Please also check the PR description on how this was tested.

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.