I am running this code on two 80GB A100 GPUs, with PEFT's LoRA added to the policy in train.py (a simplified sketch of that setup follows the traceback below). Training works normally with BasicTrainer, but when I switch to FSDPTrainer I get:
Traceback (most recent call last):
File "/direct-preference-optimization/train.py", line 127, in main
mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
while not context.join():
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 189, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
fn(i, *args)
File "direct-preference-optimization/train.py", line 43, in worker_main
trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
File "direct-preference-optimization/trainers.py", line 469, in __init__
self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
_auto_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
_recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
wrapped_child, num_wrapped_params = _recursive_wrap(
[Previous line repeated 2 more times]
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
return wrapper_cls(module, **kwargs)
File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
_init_param_handle_from_module(
File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
_init_param_handle_from_params(state, managed_params, fully_sharded_module)
File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
handle = FlatParamHandle(
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
self._init_flat_param_and_metadata(
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
) = self._validate_tensors_to_flatten(params)
File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
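For reference, this is roughly how the LoRA adapter is attached to the policy in train.py. It is a simplified sketch: the config field names approximate the repo's, and the r / lora_alpha / target_modules values are placeholders rather than my exact settings.

```python
import torch
import transformers
from peft import LoraConfig, get_peft_model

# Base policy loaded in bfloat16 (matching the repo's policy dtype).
policy = transformers.AutoModelForCausalLM.from_pretrained(
    config.model.name_or_path,
    torch_dtype=torch.bfloat16,
)

# Attach LoRA adapters with PEFT; the values below are placeholders.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
policy = get_peft_model(policy, lora_config)
policy.print_trainable_parameters()
```

I assume the new lora_A / lora_B weights end up in float32 even though the base model is in bfloat16, which would explain the mixed dtypes FSDP complains about.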
I suspect the mismatch comes from the LoRA adapter parameters. I tried casting the LoRA modules to bfloat16, but other ValueErrors occurred. I also tried use_orig_params=True; it does not help, and the same error is raised:
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
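Concretely, here is roughly what those two attempts looked like (a sketch: the "lora_" substring check is my assumption about how PEFT names the adapter parameters, and the FSDP call mirrors trainers.py line 469):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Attempt 1: cast the LoRA adapter weights to bfloat16 so that every tensor
# FSDP flattens shares the base model's dtype (this raised other ValueErrors).
for name, param in policy.named_parameters():
    if "lora_" in name:  # PEFT's lora_A / lora_B parameters (assumption)
        param.data = param.data.to(torch.bfloat16)

# Attempt 2: inside FSDPTrainer.__init__ (trainers.py), pass use_orig_params=True
# so FSDP exposes the original parameters as views into the flat parameter.
# The same dtype ValueError is still raised.
self.policy = FSDP(
    policy,
    **shared_fsdp_kwargs,
    use_orig_params=True,
    mixed_precision=policy_mp_policy,
)
```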
How can I solve this?