CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Error when running Ray Tune to launch hyperparameter sweep #597

Open Jing-L97 opened 4 months ago

Jing-L97 commented 4 months ago

🐛 Describe the bug

Hi, we encountered a DistributedDataParallel error when running the example code with the Ray Tune sweep, in which we set the distributed type to 'NO':

ray start --head --port=6379
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.py

Here is the traceback that we encountered:

Traceback (most recent call last):
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/_private/worker.py", line 2524, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::_Inner.train() (pid=1885930, ip=10.20.0.6, actor_id=6d08bc117a6b35cc7647003f01000000, repr=AccelerateTrainer)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 375, in train
    raise skipped from exception_cause(skipped)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(AttributeError): ray::_RayTrainWorker__execute.get_next() (pid=1886047, ip=10.20.0.6, actor_id=dd5dcbaf834905aa00b49be601000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7f3d4c6d20a0>)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/worker_group.py", line 32, in __execute
    raise skipped from exception_cause(skipped)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/ray/train/huggingface/accelerate/accelerate_trainer.py", line 411, in _accelerate_train_loop_per_worker
    return train_loop_per_worker(*args, **kwargs)
  File "/scratch2/jliu/CF_RL/scripts/trlx/examples/ppo_sentiments.py", line 47, in main
    trlx.train(
  File "/scratch2/jliu/CF_RL/scripts/trlx/trlx/trlx.py", line 92, in train
    trainer = get_trainer(config.train.trainer)(
  File "/scratch2/jliu/CF_RL/scripts/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 74, in __init__
    if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
  File "/scratch2/jliu/.conda/envs/RL/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'peft_type'
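
For context, the AttributeError appears to come from the fact that torch's DistributedDataParallel wrapper does not forward arbitrary Python attributes to the module it wraps, so an attribute such as peft_type that is set on the underlying model has to be reached through .module or read defensively. A minimal sketch of such a defensive read (the helper name is ours, not from trlx or Ray):

import torch.nn as nn

def read_peft_type(model: nn.Module):
    # Minimal sketch, not trlx code: DDP keeps the real model on `.module`,
    # so custom attributes such as `peft_type` live on the wrapped module.
    if isinstance(model, nn.parallel.DistributedDataParallel):
        model = model.module
    # getattr with a default avoids the AttributeError shown in the traceback.
    return getattr(model, "peft_type", None)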

The same error occurred when we changed the accelerate config file to the YAML settings below:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Thank you very much!

Which trlX version are you using?

https://github.com/CarperAI/trlx/tree/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92

Additional system and package information

transformers==4.32.0, python==3.9

arxaqapi commented 4 months ago

There seems to be an issue with the if statements on lines 74, 398, and 424 of the trlx/trainer/accelerate_ppo_trainer.py file.

The check on self.model.peft_type should be guarded with hasattr, like this:

if ... and hasattr(self.model, "peft_type")
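
Applied to the condition from the traceback above (line 74 of trlx/trainer/accelerate_ppo_trainer.py, inside the trainer's __init__), one way to write that guard while keeping the original intent (treat a missing peft_type the same as a falsy one) might look like the fragment below; this is only an illustration of the suggestion, not a tested patch:

# Illustrative fragment only, not a tested trlx patch: dereference `peft_type`
# only when the (possibly DDP-wrapped) model actually exposes the attribute.
if not hasattr(self.model, "frozen_head") and not (
    hasattr(self.model, "peft_type") and self.model.peft_type
):
    ...  # body of the branch unchanged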