qZhang88 opened this issue 3 weeks ago
Could you share a more detailed error message?
The detailed error messages are as follows:
[rank0]: Traceback (most recent call last):
[rank0]: File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 270, in <module>
[rank0]: main()
[rank0]: File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 261, in main
[rank0]: trainer.train()
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2032, in _inner_training_loop
[rank0]: self.model = self.accelerator.prepare(self.model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1292, in prepare
[rank0]: result = tuple(
[rank0]: ^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
[rank0]: self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
[rank0]: return self.prepare_model(obj, device_placement=device_placement)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
[rank0]: model = FSDP(model, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 485, in __init__
[rank0]: _auto_wrap(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank0]: _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: [Previous line repeated 2 more times]
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank0]: return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank0]: return wrapper_cls(module, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank0]: _init_param_handle_from_module(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank0]: _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank0]: handle = FlatParamHandle(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank0]: self._init_flat_param_and_metadata(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank0]: ) = self._validate_tensors_to_flatten(params)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank0]: raise ValueError(
[rank0]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[rank1]: (traceback identical to rank0 above, ending with) ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
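For reference, here is a quick diagnostic I would run right before `accelerator.prepare` to see which parameters end up in a different dtype than the rest of the model (a sketch; `model` is assumed to be the already-loaded quantized/PEFT model):

```python
import torch
from collections import Counter

# FSDP flattens parameters per wrapped module and requires them to share
# one dtype; counting the dtypes shows where the fp32 leftovers come from.
print(Counter(p.dtype for p in model.parameters()))

# List parameters that are not bfloat16 (typically norm or LoRA weights
# re-cast to fp32 by a casting helper such as prepare_model_for_kbit_training).
for name, param in model.named_parameters():
    if param.dtype != torch.bfloat16:
        print(name, param.dtype)
```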
I am running code modified from this script:
https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py
I am running with QLoRA, and the source code for the BnB config is modified to support the bnb_4bit_quant_storage parameter (see the sketch after the flags below):
--load_in_4bit True \
--use_bnb_nested_quant True \
--bnb_4bit_quant_storage bfloat16 \
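Roughly, those flags translate into a BitsAndBytesConfig like the one below (a sketch; the quant type, compute dtype, and model id are assumptions). The key point for FSDP is that bnb_4bit_quant_storage matches the dtype the rest of the model is loaded in, so every parameter FSDP flattens has a uniform dtype:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # --load_in_4bit True
    bnb_4bit_use_double_quant=True,         # --use_bnb_nested_quant True
    bnb_4bit_quant_type="nf4",              # assumed default
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed
    bnb_4bit_quant_storage=torch.bfloat16,  # --bnb_4bit_quant_storage bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",               # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,             # must match quant_storage so FSDP sees one dtype
)
```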
If QLoRA is not used, FSDP works fine, but training then hits OOM on some long training examples. So I am trying to use FSDP with QLoRA.
Checking the note in SFTTrainer, the error is caused by peft_module_casting_to_bf16 or prepare_model_for_kbit_training:
# Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
# peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
# QLoRA + FSDP / DS-Zero3
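In other words, for QLoRA + FSDP the LoRA adapter should be applied by the trainer rather than by calling those casting helpers yourself. A minimal sketch of what that looks like with DPOTrainer, assuming the quantized `model` from the snippet above and that `training_args`, `train_dataset`, and `tokenizer` are defined elsewhere (the LoRA hyperparameters are assumptions):

```python
from peft import LoraConfig
from trl import DPOTrainer

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,           # assumed hyperparameters
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

trainer = DPOTrainer(
    model=model,                 # quantized base model, NOT passed through
                                 # prepare_model_for_kbit_training / peft_module_casting_to_bf16
    ref_model=None,              # with a PEFT adapter, the base weights serve as the reference
    args=training_args,          # training config with bf16 enabled; FSDP set up via accelerate
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,     # the trainer wraps the model with the adapter itself
)
trainer.train()
```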
+1
Running DPO with Qwen hits the same flatten problem. FSDP config as follows: