hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

The RM model was trained with full parameters; can't PPO be trained with full parameters too? #2751

Closed · xienan0326 closed 8 months ago

xienan0326 commented 8 months ago

Reminder

Reproduction

python -m torch.distributed.run \
    --nproc_per_node 8 \
    --nnodes 1 \
    src/train_bash.py \
    --deepspeed ds_config3.json \
    --stage ppo \
    --do_train \
    --model_name_or_path out_dir/sft_test \
    --reward_model out_dir/rm_test \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type full \
    --output_dir out_dir/ppo_test \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --lr_scheduler_type cosine \
    --top_k 0 \
    --top_p 0.9 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16

Expected behavior

How can this be fixed?

System Info

Traceback (most recent call last):
  File "/workspace/LLaMA-Factory/src/train_bash.py", line 14, in <module>
    main()
  File "/workspace/LLaMA-Factory/src/train_bash.py", line 5, in main
    run_exp()
  File "/workspace/LLaMA-Factory/src/llmtuner/train/tuner.py", line 26, in run_exp
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
  File "/workspace/LLaMA-Factory/src/llmtuner/hparams/parser.py", line 92, in get_train_args
    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
  File "/workspace/LLaMA-Factory/src/llmtuner/hparams/parser.py", line 78, in _parse_train_args
    return _parse_args(parser, args)
  File "/workspace/LLaMA-Factory/src/llmtuner/hparams/parser.py", line 45, in _parse_args
    (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
  File "/opt/conda/lib/python3.10/site-packages/transformers/hf_argparser.py", line 338, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 34, in __init__
  File "/workspace/LLaMA-Factory/src/llmtuner/hparams/finetuning_args.py", line 200, in __post_init__
    raise ValueError("reward_model_type cannot be lora for Freeze/Full PPO training.")
ValueError: reward_model_type cannot be lora for Freeze/Full PPO training.
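For context: the ValueError comes from a post-init validation on the parsed finetuning arguments. Since the reproduction command never sets reward_model_type, its default is evidently "lora", which the parser rejects when finetuning_type is freeze or full. A minimal, self-contained sketch of that check follows; it is paraphrased from the traceback and error message, and any field or class name beyond those is an assumption, not the actual source.

# Sketch of the validation implied by the traceback. The real check lives in
# src/llmtuner/hparams/finetuning_args.py (__post_init__); names are illustrative.
from dataclasses import dataclass

@dataclass
class FinetuningArgsSketch:
    stage: str = "ppo"
    finetuning_type: str = "full"
    reward_model_type: str = "lora"  # apparent default, which triggers the error

    def __post_init__(self):
        # A LoRA-adapter reward model can only be attached during LoRA PPO training,
        # so reject it for freeze/full PPO runs.
        if self.stage == "ppo" and self.finetuning_type != "lora" and self.reward_model_type == "lora":
            raise ValueError("reward_model_type cannot be lora for Freeze/Full PPO training.")

FinetuningArgsSketch(reward_model_type="full")  # passes once the flag is set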

Others

No response

hiyouga commented 8 months ago

--reward_model_type full
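That is, full-parameter PPO needs the reward model loaded as a full model rather than a LoRA adapter. Applying the suggested flag, the reproduction command would become (unchanged except for the added --reward_model_type line):

python -m torch.distributed.run \
    --nproc_per_node 8 \
    --nnodes 1 \
    src/train_bash.py \
    --deepspeed ds_config3.json \
    --stage ppo \
    --do_train \
    --model_name_or_path out_dir/sft_test \
    --reward_model out_dir/rm_test \
    --reward_model_type full \
    --finetuning_type full \
    ...  # remaining flags as in the reproduction command above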