huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

FSDP Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32 #1723

Open qZhang88 opened 3 weeks ago

qZhang88 commented 3 weeks ago

Running DPO with Qwen hits the flatten-dtype error under FSDP. The FSDP config is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
num_machines: 1
num_processes: 2
main_training_function: main
mixed_precision: bf16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
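
For reference, the error in the title means FSDP could not flatten a wrapped module's parameters because they did not all share one dtype: FSDP builds one flat buffer per wrapped Qwen2DecoderLayer, so every tensor inside it must have the same dtype. A minimal sketch of loading the model uniformly in bfloat16 for the non-quantized case (the model id is a placeholder, not taken from the issue):

import torch
from transformers import AutoModelForCausalLM

# Keep every parameter in bfloat16 so FSDP can flatten each wrapped
# Qwen2DecoderLayer without mixing bf16 and fp32 tensors.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",   # placeholder model id for illustration
    torch_dtype=torch.bfloat16,
)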
vwxyzjn commented 3 weeks ago

Could you share a more detailed error message?

qZhang88 commented 3 weeks ago

Detailed traceback as follows:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 270, in <module>
[rank0]:     main()
[rank0]:   File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 261, in main
[rank0]:     trainer.train()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2032, in _inner_training_loop
[rank0]:     self.model = self.accelerator.prepare(self.model)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1292, in prepare
[rank0]:     result = tuple(
[rank0]:              ^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
[rank0]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
[rank0]:     return self.prepare_model(obj, device_placement=device_placement)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
[rank0]:     model = FSDP(model, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 485, in __init__
[rank0]:     _auto_wrap(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank0]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank0]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank0]:     return wrapper_cls(module, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank0]:     _init_param_handle_from_module(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:              ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank0]:     self._init_flat_param_and_metadata(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank0]:     raise ValueError(
[rank0]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[rank1]: Traceback (most recent call last):
[rank1]:   ... (identical to the rank0 traceback above) ...
[rank1]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I am running code modified from this script: https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py

I am running with QLoRA, and the source code for the BnB config is modified to support the bnb_4bit_quant_storage param (see the sketch after the next paragraph):

    --load_in_4bit True \
    --use_bnb_nested_quant True \
    --bnb_4bit_quant_storage bfloat16 \

If QLoRA is not used, FSDP works fine, but training hits OOM errors on some long training examples. So I am trying to use FSDP with QLoRA.
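
For illustration, a hedged sketch of the BitsAndBytesConfig those flags map to; the key point is that bnb_4bit_quant_storage matches the bf16 mixed-precision dtype so the quantized weights present a single uniform dtype to FSDP (the model id, the quant type, and the exact flag wiring in the modified script are assumptions):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # --load_in_4bit True
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,          # --use_bnb_nested_quant True
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,   # --bnb_4bit_quant_storage bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",                # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,              # keep non-quantized params in bf16 too
)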

qZhang88 commented 2 weeks ago

Checking the note in SFTTrainer, the error is caused by calling peft_module_casting_to_bf16 or prepare_model_for_kbit_training:

                # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
                # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
                # QLoRA + FSDP / DS-Zero3
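
For illustration, a hedged sketch of guarding those calls so they are skipped when FSDP (or DeepSpeed ZeRO-3) is active; the use_fsdp flag, the LoRA hyperparameters, and the target modules are assumptions rather than code from the issue, and model is the 4-bit model from the sketch above:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

use_fsdp = True  # assumed flag; in practice derive this from the accelerate/FSDP setup

if not use_fsdp:
    # Fine for single-GPU / DDP QLoRA, but this helper upcasts layer norms and
    # other modules to float32, which breaks FSDP's uniform-dtype flattening.
    model = prepare_model_for_kbit_training(model)

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)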
Minami-su commented 6 days ago

+1