OpenRLHF / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & RingAttention)
https://openrlhf.readthedocs.io/
Apache License 2.0

[rank3]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu! #325

Status: Closed (xiechengmude closed this issue 5 months ago)

xiechengmude commented 5 months ago

I got this error when running train_ppo_llama.sh:

set -x

read -r -d '' training_commands <<EOF
examples/train_ppo.py \
    --pretrain OpenLLMAI/Llama-2-7b-sft-model-ocra-500k \
    --reward_pretrain OpenLLMAI/Llama-2-7b-rm-anthropic_hh-lmsys-oasst-webgpt \
    --save_path ./ckpt/7b_llama \
    --save_steps -1 \
    --logging_steps 1 \
    --eval_steps -1 \
    --micro_train_batch_size 2 \
    --train_batch_size 128 \
    --micro_rollout_batch_size 4 \
    --rollout_batch_size 1024 \
    --max_epochs 1 \
    --prompt_max_len 1024 \
    --generate_max_len 1024 \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate 5e-7 \
    --critic_learning_rate 9e-6 \
    --init_kl_coef 0.01 \
    --prompt_data Open-Orca/OpenOrca,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward \
    --prompt_data_probs 0.4,0.5,0.1 \
    --max_samples 80000 \
    --normalize_reward \
    --actor_init_on_gpu \
    --flash_attn \
    --gradient_checkpointing
EOF
     # --wandb [WANDB_TOKENS] or True (use wandb login command)

if [[ ${1} != "slurm" ]]; then
    deepspeed $training_commands
fi

Episode [1/1]: 0%| | 0/2500 [00:00<?, ?it/s]
/root/miniconda3/envs/openrlhf/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:535: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.
  warnings.warn(
[rank3]: Traceback (most recent call last):
[rank3]:   File "/workspace/OpenRLHF/examples/train_ppo.py", line 347, in <module>
[rank3]:   File "/workspace/OpenRLHF/examples/train_ppo.py", line 240, in train
[rank3]:   File "/root/.local/lib/python3.10/site-packages/openrlhf/trainer/ppo_trainer.py", line 176, in fit
[rank3]:     experience = self.experience_maker.make_experience(rand_prompts, **self.generate_kwargs)
[rank3]:   File "/root/miniconda3/envs/openrlhf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/root/.local/lib/python3.10/site-packages/openrlhf/trainer/ppo_utils/experience_maker.py", line 120, in make_experience
[rank3]:     sequences, attention_mask, action_mask = self.actor.generate(**inputs, **generate_kwargs)
[rank3]:   File "/root/miniconda3/envs/openrlhf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/root/.local/lib/python3.10/site-packages/openrlhf/models/actor.py", line 136, in generate
[rank3]:     return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)
[rank3]:   File "/root/.local/lib/python3.10/site-packages/openrlhf/models/actor.py", line 156, in process_sequences
[rank3]:     mask = (mask <= eos_indices) & (mask >= first_token_indices)
[rank3]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu!
[2024-06-14 18:40:51,919] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289741
[2024-06-14 18:40:52,442] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289742
[2024-06-14 18:40:52,994] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289743
[2024-06-14 18:40:53,500] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289744
[2024-06-14 18:40:53,501] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289745
[2024-06-14 18:40:54,047] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289746
[2024-06-14 18:40:54,596] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289747
[2024-06-14 18:40:55,183] [INFO] [launch.py:319:sigkill_handler] Killing subprocess 289748
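One detail in this traceback: `train_ppo.py` runs from the `/workspace/OpenRLHF` checkout, but every `openrlhf` frame is loaded from `/root/.local/lib/python3.10/site-packages/openrlhf`, a separately installed copy whose version may not match the checkout. The thread does not confirm this as the cause. Below is a minimal sketch for checking which copy gets imported, assuming `python` is the same interpreter DeepSpeed launches and `/workspace/OpenRLHF` is the checkout (overridable through `OPENRLHF_CHECKOUT`, an illustrative variable name); `is_inside` is a hypothetical helper, not part of OpenRLHF.

```bash
#!/usr/bin/env bash
# Sketch: report whether `import openrlhf` resolves to the source checkout or to
# another installed copy (e.g. one under ~/.local, as seen in the traceback).

# is_inside DIR ROOT (hypothetical helper): succeed if DIR is ROOT or lies under ROOT.
is_inside() {
    case "${1%/}/" in
        "${2%/}/"*) return 0 ;;
        *) return 1 ;;
    esac
}

# Checkout path taken from the traceback; OPENRLHF_CHECKOUT is an assumed override.
checkout="${OPENRLHF_CHECKOUT:-/workspace/OpenRLHF}"
# Directory of the package this python actually imports; empty if the import fails.
pkg_dir="$(python -c 'import os, openrlhf; print(os.path.dirname(openrlhf.__file__))' 2>/dev/null)"

if [ -z "$pkg_dir" ]; then
    echo "openrlhf is not importable with this python"
elif is_inside "$pkg_dir" "$checkout"; then
    echo "openrlhf comes from the checkout: $pkg_dir"
else
    echo "openrlhf comes from $pkg_dir, not from $checkout"
fi
```

The trailing-slash normalization in `is_inside` keeps a sibling directory such as `/workspace/OpenRLHF2` from being mistaken for a path inside `/workspace/OpenRLHF`.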

hijkzzz commented 5 months ago

Use deepspeed <= 0.14.0
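To apply this suggestion before launching, a small pre-flight check can compare the installed DeepSpeed version against 0.14.0. This is a sketch assuming `pip` belongs to the same environment as the training run and GNU `sort -V` is available; `version_le` is a hypothetical helper. The next comment reports the same error on 0.13.5, so passing this check is not a guarantee.

```bash
#!/usr/bin/env bash
# Sketch: warn if the installed DeepSpeed is newer than the 0.14.0 ceiling suggested here.

# version_le A B (hypothetical helper): succeed if version A <= version B,
# relying on GNU `sort -V` for version ordering (so 0.9.3 sorts before 0.14.0).
version_le() {
    [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$1" ]
}

max_ds="0.14.0"
# Installed version as reported by pip; empty if deepspeed (or pip) is missing.
ds_version="$(pip show deepspeed 2>/dev/null | awk '/^Version:/ {print $2}')"

if [ -z "$ds_version" ]; then
    echo "deepspeed not found by pip"
elif version_le "$ds_version" "$max_ds"; then
    echo "deepspeed $ds_version <= $max_ds"
else
    echo "deepspeed $ds_version > $max_ds; to downgrade: pip install 'deepspeed<=$max_ds'"
fi
```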

cassanof commented 5 months ago

Same issue here, using deepspeed 0.13.5.

cassanof commented 5 months ago

Checking out release 0.3.0 made the error go away.
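A sketch of pinning the checkout and the imported package to one release, so the example scripts and the `openrlhf` modules they import come from the same version. The tag name `v0.3.0` is an assumption (list the real tags with `git tag`), `pin_openrlhf_release` and `run` are hypothetical helpers, and `pip install -e` assumes the repository can be installed from its root. Setting `DRY_RUN=1` prints the commands instead of running them.

```bash
#!/usr/bin/env bash
# Sketch: check out one OpenRLHF release and reinstall it so the examples and the
# imported `openrlhf` package come from the same version.

# run CMD... (hypothetical helper): execute CMD, or only print it when DRY_RUN=1.
run() {
    if [ "${DRY_RUN:-0}" = 1 ]; then
        echo "$*"
    else
        "$@"
    fi
}

# pin_openrlhf_release REPO TAG (hypothetical helper): fetch tags, check out TAG,
# then install REPO in editable mode. Stops at the first failing step.
pin_openrlhf_release() {
    local repo="$1"
    local tag="$2"
    run git -C "$repo" fetch --tags &&
        run git -C "$repo" checkout "$tag" &&
        run pip install -e "$repo"
}

# Example (tag name assumed): pin_openrlhf_release /workspace/OpenRLHF v0.3.0
```

If an older copy remains under `~/.local` (as in the traceback), it may still take precedence, because Python's user site directory is placed on `sys.path` before the environment's site-packages; that copy may need to be removed with `pip uninstall openrlhf`.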