OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

CUDA out of memory when I run train_ppo_llama_ray.sh on 4x RTX 4090 (24GB) #275

Closed: libowen424 closed this issue 2 months ago

libowen424 commented 2 months ago

My configuration:

```bash
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json='{"working_dir": "/openrlhf", "pip": "/openrlhf/requirements.txt"}' \
    -- python3 examples/train_ppo_ray.py \
    --ref_num_nodes 1 \
    --ref_num_gpus_per_node 1 \
    --reward_num_nodes 1 \
    --reward_num_gpus_per_node 1 \
    --critic_num_nodes 1 \
    --critic_num_gpus_per_node 1 \
    --actor_num_nodes 1 \
    --actor_num_gpus_per_node 1 \
    --pretrain /root/.cache/huggingface/hub/llama-2-7b-chat-hf \
    --reward_pretrain /root/.cache/huggingface/hub/models--OpenLLMAI--Llama-2-7b-rm-anthropic_hh-lmsys-oasst-webgpt/snapshots/a982afeed00fac9767d53aecde5b88947b1be194 \
    --save_path /openrlhf/examples/test_scripts/ckpt/7b_llama \
    --micro_train_batch_size 8 \
    --train_batch_size 64 \
    --micro_rollout_batch_size 16 \
    --rollout_batch_size 1024 \
    --max_epochs 1 \
    --prompt_max_len 1024 \
    --generate_max_len 1024 \
    --zero_stage 2 \
    --bf16 \
    --actor_learning_rate 5e-7 \
    --critic_learning_rate 9e-6 \
    --init_kl_coef 0.01 \
    --prompt_data Open-Orca/OpenOrca,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward \
    --prompt_data_probs 0.4,0.5,0.1 \
    --max_samples 80000 \
    --normalize_reward \
    --actor_init_on_gpu \
    --adam_offload \
    --flash_attn \
    --gradient_checkpointing \
    --lora_rank 4
```

What should I change? Thanks so much!

hijkzzz commented 2 months ago

Reduce `--micro_rollout_batch_size` and `--micro_train_batch_size` to 1 and 1. The adjusted command is sketched below.
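
For reference, this is the same command as in the original post with only those two flags changed (an untested sketch; all other flags, paths, and datasets are kept exactly as above, and the effective batch sizes `--train_batch_size`/`--rollout_batch_size` are unchanged, so the trainer simply accumulates over more micro steps):

```bash
# Sketch of the suggested fix: only the two micro batch sizes differ from the
# original command; smaller micro batches lower per-GPU activation memory.
ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json='{"working_dir": "/openrlhf", "pip": "/openrlhf/requirements.txt"}' \
    -- python3 examples/train_ppo_ray.py \
    --ref_num_nodes 1 \
    --ref_num_gpus_per_node 1 \
    --reward_num_nodes 1 \
    --reward_num_gpus_per_node 1 \
    --critic_num_nodes 1 \
    --critic_num_gpus_per_node 1 \
    --actor_num_nodes 1 \
    --actor_num_gpus_per_node 1 \
    --pretrain /root/.cache/huggingface/hub/llama-2-7b-chat-hf \
    --reward_pretrain /root/.cache/huggingface/hub/models--OpenLLMAI--Llama-2-7b-rm-anthropic_hh-lmsys-oasst-webgpt/snapshots/a982afeed00fac9767d53aecde5b88947b1be194 \
    --save_path /openrlhf/examples/test_scripts/ckpt/7b_llama \
    --micro_train_batch_size 1 \
    --train_batch_size 64 \
    --micro_rollout_batch_size 1 \
    --rollout_batch_size 1024 \
    --max_epochs 1 \
    --prompt_max_len 1024 \
    --generate_max_len 1024 \
    --zero_stage 2 \
    --bf16 \
    --actor_learning_rate 5e-7 \
    --critic_learning_rate 9e-6 \
    --init_kl_coef 0.01 \
    --prompt_data Open-Orca/OpenOrca,Dahoas/full-hh-rlhf,tasksource/oasst1_pairwise_rlhf_reward \
    --prompt_data_probs 0.4,0.5,0.1 \
    --max_samples 80000 \
    --normalize_reward \
    --actor_init_on_gpu \
    --adam_offload \
    --flash_attn \
    --gradient_checkpointing \
    --lora_rank 4
```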

libowen424 commented 2 months ago

> Reduce `--micro_rollout_batch_size` and `--micro_train_batch_size` to 1 and 1.