hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0
32.7k stars 4.01k forks source link

dpo后推理速度变慢将近10倍 #1546

Closed fst813 closed 11 months ago

fst813 commented 11 months ago

在模型上直接进行dpo训练,没有加载sft的lora参数,训练参数如下:

src/train_bash.py \
        --stage dpo \
        --dpo_bet 0.3 \
        --model_name_or_path model_path/ \
        --do_train \
        --dataset pattern_dpo_sample_2 \
        --template codellama \
        --finetuning_type lora \
        --lora_target all \
        --output_dir out_dir \
        --overwrite_cache \
        --overwrite_output_dir \
        --per_device_train_batch_size 6 \
        --gradient_accumulation_steps 1 \
        --flash_attn True \
        --lr_scheduler_type linear \
        --logging_steps 10 \
        --save_steps 100 \
        --cutoff_len 4096 \
        --learning_rate 1e-6 \
        --num_train_epochs 5.0 \
        --plot_loss \
        --fp16 \
        --use_fast_tokenizer False \
        --lora_alpha 16 \
        --lora_rank 8 \
        --lora_dropout 0.1

训练一切正常。 推理参数如下:

python3 src/train_bash.py \
    --stage sft \
    --model_name_or_path $MODEL_PATH \
    --do_predict \
    --dataset v6_1008_compare_test_codellama_3_3500_llamatoken_fix_newfomate \
    --template codellama \
    --flash_attn True \
    --cutoff_len 4096 \
    --use_fast_tokenizer False \
    --finetuning_type lora \
    --lora_target all \
    --checkpoint_dir lora_dir \
    --output_dir $OUTPUT_DIR \
    --per_device_eval_batch_size 10 \
    --max_samples 1035 \
    --num_beams 1 \
    --max_new_tokens 512 \
    --predict_with_generate

推理速度比之前sft的慢了将近10倍,是我参数有什么问题吗?

fst813 commented 11 months ago

合并再推理速度同样很慢

Qyf007 commented 5 months ago

同问

fst813 commented 5 months ago

@Qyf007 应该是训崩了,模型没有正确终止

wanghanone commented 1 month ago

@Qyf007 应该是训崩了,模型没有正确终止 怎么看是不是训崩了呢?

wanghanone commented 1 month ago

{ "best_metric": null, "best_model_checkpoint": null, "epoch": 2.6879999999999997, "eval_steps": 500, "global_step": 21, "is_hyper_param_search": false, "is_local_process_zero": true, "is_world_process_zero": true, "log_history": [ { "epoch": 2.69, "step": 21, "total_flos": 0.0, "train_loss": 0.4635912577311198, "train_runtime": 31519.1577, "train_samples_per_second": 0.095, "train_steps_per_second": 0.001 } ], "logging_steps": 100, "max_steps": 21, "num_input_tokens_seen": 0, "num_train_epochs": 3, "save_steps": 500, "total_flos": 0.0, "train_batch_size": 1, "trial_name": null, "trial_params": null } 我训练完state文件是这样的,怎么能看出来是不是训崩了呢?

fst813 commented 1 month ago

@wanghanone 看看推理结果是不是乱了或者停不下来了

wanghanone commented 1 month ago

@wanghanone 看看推理结果是不是乱了或者停不下来了

有的时候会停不下来,但是大部分还是可以正常回答。