eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

ValueError when using peft on FSDPTrainer #90

Open AragornHorse opened 1 week ago

AragornHorse commented 1 week ago

ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I tried running this code on two 80GB A100 GPUs and added PEFT's LoRA in train.py:

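from peft import LoraConfig, get_peft_model

# Wrap the policy with LoRA adapters; config.lora.* is a custom section in my config
peft_config = LoraConfig(
    r=config.lora.r,
    lora_alpha=config.lora.alpha,
    lora_dropout=config.lora.dropout,
)

policy = get_peft_model(policy, peft_config)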

Training runs normally with BasicTrainer, but with FSDPTrainer I get the following error:

Traceback (most recent call last):
  File "/direct-preference-optimization/train.py", line 127, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, policy, reference_model), join=True)
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 282, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 238, in start_processes
    while not context.join():
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 189, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
    fn(i, *args)
  File "direct-preference-optimization/train.py", line 43, in worker_main
    trainer = TrainerClass(policy, config, config.seed, config.local_run_dir, reference_model=reference_model, rank=rank, world_size=world_size)
  File "direct-preference-optimization/trainers.py", line 469, in __init__
    self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
    _auto_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 102, in _auto_wrap
    _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 544, in _recursive_wrap
    wrapped_child, num_wrapped_params = _recursive_wrap(
  [Previous line repeated 2 more times]
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 562, in _recursive_wrap
    return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
  File "python3.9/site-packages/torch/distributed/fsdp/wrap.py", line 491, in _wrap
    return wrapper_cls(module, **kwargs)
  File "python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
    _init_param_handle_from_module(
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 603, in _init_param_handle_from_module
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
  File "python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 615, in _init_param_handle_from_params
    handle = FlatParamHandle(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 583, in __init__
    self._init_flat_param_and_metadata(
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 633, in _init_flat_param_and_metadata
    ) = self._validate_tensors_to_flatten(params)
  File "python3.9/site-packages/torch/distributed/fsdp/_flat_param.py", line 771, in _validate_tensors_to_flatten
    raise ValueError(
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
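The traceback ends in FlatParamHandle, which flattens all parameters of one FSDP-wrapped unit into a single flat tensor, so every parameter in that unit must share a dtype. A likely (unconfirmed) source of the mix is bf16 base weights next to LoRA adapter weights that PEFT created in float32. A minimal sketch for listing which parameters have which dtype before calling FSDP; the helper `dtype_summary` is hypothetical, and the small toy model only stands in for the real `policy` from train.py:

```python
from collections import defaultdict

import torch
import torch.nn as nn


def dtype_summary(model: nn.Module) -> dict:
    """Hypothetical debugging helper: group parameter names by dtype."""
    groups = defaultdict(list)
    for name, param in model.named_parameters():
        groups[param.dtype].append(name)
    return dict(groups)


# Toy stand-in: a bf16 "base" layer plus an adapter left in the default float32,
# mimicking the bfloat16/float32 mix reported in the traceback.
toy = nn.ModuleDict({
    "base": nn.Linear(8, 8).to(torch.bfloat16),
    "lora_A": nn.Linear(8, 2, bias=False),
})

for dtype, names in dtype_summary(toy).items():
    print(dtype, names)
```

Running the same helper on the PEFT-wrapped policy right before the FSDP(...) call should show whether more than one dtype is present.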

I tried casting the LoRA modules to bfloat16, but that raised other ValueErrors. I also tried use_orig_params=True, but it didn't help.
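As far as I can tell from the FSDP source, use_orig_params=True only relaxes the check that flattened parameters share the same requires_grad; the dtype check in _validate_tensors_to_flatten still applies, which would explain why that flag alone did not help. Conversely, a uniform dtype alone may still trip the requires_grad check, since LoRA freezes the base weights while the adapters stay trainable. A minimal sketch of combining both, assuming a single storage dtype for every parameter is acceptable; `cast_to_uniform_dtype` is a hypothetical helper, and this combination is an untested assumption, not a confirmed fix:

```python
import torch
import torch.nn as nn


def cast_to_uniform_dtype(model: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Hypothetical helper: give every floating-point parameter the same dtype."""
    # Module.to(dtype) casts floating-point parameters and buffers in place and
    # keeps each parameter's requires_grad flag unchanged.
    model.to(dtype)
    remaining = {p.dtype for p in model.parameters() if p.is_floating_point()}
    if remaining != {dtype}:
        raise ValueError(f"parameters still have mixed dtypes: {remaining}")
    return model


# Toy stand-in for a frozen bf16 base layer plus a trainable float32 adapter.
toy = nn.ModuleDict({
    "base": nn.Linear(8, 8).to(torch.bfloat16),
    "lora_A": nn.Linear(8, 2, bias=False),
})
toy["base"].requires_grad_(False)  # base weights frozen, as in LoRA training

cast_to_uniform_dtype(toy, torch.float32)

# In trainers.py the idea would be (assumption, not run here):
#   cast_to_uniform_dtype(policy, torch.float32)
#   set use_orig_params=True wherever the shared FSDP kwargs are built, so that
#   frozen and trainable parameters may share one flat parameter, then
#   self.policy = FSDP(policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)
```

Storing weights in float32 and letting the existing mixed_precision policy handle bf16 compute keeps full-precision weights for the optimizer at a higher memory cost; casting everything to torch.bfloat16 instead would roughly halve parameter memory.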

How can I solve this?