Open primary-studyer opened 5 months ago
我用的是8卡A800 80G,40G的卡我没有试过,如果per device batch size=1还会爆显存的话建议使用lora
lora微调的效果有点不说人话
你运行成功了吗?为什么我运行不起来,报错
deepspeed --num_gpus 8 ../src/train_bash.py \ --deepspeed ds_z2_config.json \ --dataset_dir ../data \ --stage orpo \ --do_train \ --model_name_or_path /nfsdata/gpu007/dingfei/commondata/huggingface/Meta-Llama-3-8B-Instruct \ --dataset dpo_mix_en,dpo_mix_zh \ --template llama3 \ --finetuning_type full \ --output_dir /nfsdata/gpu007/dingfei/commondata/huggingface/Meta-Llama-3-8B-Instruct-saves/shenzhi-wang/Llama3-8B-Chinese-Chat-v1-test1 \ --per_device_train_batch_size 2 \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --log_level info \ --logging_steps 5 \ --save_strategy epoch \ --save_total_limit 3 \ --save_steps 100 \ --learning_rate 5e-6 \ --num_train_epochs 3.0 \ --plot_loss \ --do_eval false \ --max_steps -1 \ --bf16 true \ --seed 42 \ --warmup_ratio 0.1 \ --cutoff_len 8192 \ --flash_attn true \ --orpo_beta 0.05 \ --optim paged_adamw_32bit
deepspeed --include localhost:0,1,2,3,4,5,6,7 src/train_bash.py \ --deepspeed config_zero3.json \ --stage orpo \ --do_train \ --model_name_or_path llama3/Meta-Llama-3-8B-Instruct/ \ --dataset dpo_mix_en,dpo_mix_zh \ --template llama3 \ --finetuning_type full \ --output_dir output_llm_model/llama-3-8b-orpo-full/ \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 2 \ --lr_scheduler_type cosine \ --log_level info \ --logging_steps 5 \ --save_strategy epoch \ --save_total_limit 3 \ --save_steps 100 \ --learning_rate 5e-6 \ --num_train_epochs 3.0 \ --plot_loss \ --do_eval false \ --max_steps -1 \ --bf16 true \ --seed 42 \ --warmup_ratio 0.1 \ --cutoff_len 8192 \ --flash_attn true \ --orpo_beta 0.05 \ --optim paged_adamw_32bit
你的v1版本 --per_device_train_batch_size 2 \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 4
我这边参数是: --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 8 \ 全参微调orpo跑不起来,8卡 A100(40G) 配置 gradient_accumulation_steps从8分别调整到4, 2, 1也跑不起来。oom config_zero3换成config_zero2也是oom
请问你的微调orpo配置是多少?大概多少能跑起来?