hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0

DPO export and predict #1459

Closed: RavidLightricks closed this issue 1 year ago

RavidLightricks commented 1 year ago

Since the DPO workflow doesn't support do_predict, I'm trying to export the model and then run do_predict with the sft workflow. But the predictions I'm getting are empty strings.

python src/export_model.py \
    --model_name_or_path bigscience/bloomz-7b1 \
    --template my_template \
    --finetuning_type lora \
    --lora_target query_key_value \
    --checkpoint_dir path_to_dpo_checkpoint \
    --export_dir export_path \
    --stage dpo

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path bigscience/bloomz-7b1 \
    --do_predict \
    --checkpoint_dir export_path \
    --output_dir export_path \
    --dataset my_dataset \
    --template my_template \
    --finetuning_type full \
    --predict_with_generate

What am I doing wrong?

hiyouga commented 1 year ago

After exporting the model, pass export_path as --model_name_or_path without any --checkpoint_dir, and specify a different --output_dir.
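A minimal sketch of the corrected prediction call, built in Python so the arguments can be checked. It keeps the user's flags from the prediction command above and assumes the older `src/train_bash.py` entry point used in this thread; `build_cmd` is a hypothetical helper and `predict_output` a hypothetical output directory, neither is part of LLaMA-Factory.

```python
import shlex


def build_cmd(script, flags):
    """Turn an ordered dict of flags into an argv list; True means a bare switch."""
    argv = ["python", script]
    for name, value in flags.items():
        if value is True:
            argv.append(f"--{name}")
        else:
            argv += [f"--{name}", str(value)]
    return argv


EXPORT_DIR = "export_path"      # merged model written by src/export_model.py
PREDICT_DIR = "predict_output"  # hypothetical: any directory other than EXPORT_DIR

predict_cmd = build_cmd("src/train_bash.py", {
    "stage": "sft",
    "do_predict": True,
    "model_name_or_path": EXPORT_DIR,  # exported model replaces the original base model
    # no checkpoint_dir: the adapter weights are already merged into EXPORT_DIR
    "dataset": "my_dataset",
    "template": "my_template",
    "finetuning_type": "full",
    "output_dir": PREDICT_DIR,         # keeps predictions separate from the exported weights
    "predict_with_generate": True,
})

if __name__ == "__main__":
    # Print a copy-pasteable command instead of running it.
    print("CUDA_VISIBLE_DEVICES=0 " + shlex.join(predict_cmd))
```

Printing keeps the sketch free of side effects; to execute it instead, pass `predict_cmd` to `subprocess.run(..., check=True)` with `CUDA_VISIBLE_DEVICES` set in the environment.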

RavidLightricks commented 1 year ago

@hiyouga Still, I'm getting empty predictions, even with my own prediction script, which works on SFT models. I believe something in my export is not right.

hiyouga commented 1 year ago

Try exporting the SFT and DPO weights separately.
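One way to read this advice, sketched in Python: call `src/export_model.py` twice, merging exactly one LoRA adapter per call. This assumes the DPO adapter was trained on top of the SFT weights; `path_to_sft_checkpoint`, `sft_export` and `dpo_export` are hypothetical paths, and `export_cmd` is a hypothetical helper that only builds the argument lists.

```python
import shlex


def export_cmd(base_model, adapter_dir, export_dir):
    """Build one src/export_model.py call that merges a single LoRA adapter into a base model."""
    return [
        "python", "src/export_model.py",
        "--model_name_or_path", base_model,
        "--template", "my_template",
        "--finetuning_type", "lora",
        "--checkpoint_dir", adapter_dir,  # exactly one adapter per export
        "--export_dir", export_dir,
    ]


SFT_MERGED = "sft_export"  # hypothetical output of the first export

# First merge the SFT adapter into the original base model...
sft_export_cmd = export_cmd("bigscience/bloomz-7b1", "path_to_sft_checkpoint", SFT_MERGED)
# ...then merge the DPO adapter into the SFT-merged model (assumes DPO was trained on top of it).
dpo_export_cmd = export_cmd(SFT_MERGED, "path_to_dpo_checkpoint", "dpo_export")

if __name__ == "__main__":
    # Order matters: the second command reads the directory the first one writes.
    for argv in (sft_export_cmd, dpo_export_cmd):
        print(shlex.join(argv))
```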

RavidLightricks commented 1 year ago

I think I might be missing something. In the export step, do I still need to supply the reference model? Maybe you could add this to the README; I also noticed that a dpo_ref_model argument was added.

hiyouga commented 1 year ago

ref_model is only used to compute metrics during evaluation. Here is the workflow we recommend:

  1. train an SFT model with stage=sft
  2. merge the SFT weights by exporting the model
  3. use the export_dir as a new base model to train a DPO model with stage=dpo
  4. use cli_demo to check whether the model generates sentences, passing the DPO weights as the checkpoint_dir
  5. keep the same parameters as in step 4 and obtain generated predictions via stage=sft and do_predict
  6. merge the DPO weights into the base model (optional)
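The six steps above can be sketched as a list of command lines, built in Python and printed rather than executed. The script names and flags follow the older `src/train_bash.py`, `src/export_model.py` and `src/cli_demo.py` layout used in this thread; `--do_train` is the standard Hugging Face `TrainingArguments` switch rather than something shown in the thread, and `my_comparison_dataset` (assumed to be a pairwise preference dataset for DPO) and all output paths are hypothetical.

```python
import shlex

BASE = "bigscience/bloomz-7b1"
SFT_ADAPTER = "path_to_sft_checkpoint"   # hypothetical paths from here on
SFT_MERGED = "sft_export"
DPO_ADAPTER = "path_to_dpo_checkpoint"
COMMON = ["--template", "my_template", "--finetuning_type", "lora"]
LORA_TARGET = ["--lora_target", "query_key_value"]


def py(script, *args):
    """Prefix a script and its arguments with the interpreter."""
    return ["python", script, *args]


workflow = [
    # 1. Train an SFT LoRA adapter on the original base model.
    py("src/train_bash.py", "--stage", "sft", "--do_train", "--model_name_or_path", BASE,
       "--dataset", "my_dataset", *COMMON, *LORA_TARGET, "--output_dir", SFT_ADAPTER),
    # 2. Merge the SFT adapter into the base model.
    py("src/export_model.py", "--model_name_or_path", BASE, *COMMON,
       "--checkpoint_dir", SFT_ADAPTER, "--export_dir", SFT_MERGED),
    # 3. Train a DPO LoRA adapter with the merged SFT model as the new base.
    py("src/train_bash.py", "--stage", "dpo", "--do_train", "--model_name_or_path", SFT_MERGED,
       "--dataset", "my_comparison_dataset", *COMMON, *LORA_TARGET, "--output_dir", DPO_ADAPTER),
    # 4. Chat interactively to confirm the DPO adapter produces non-empty text.
    py("src/cli_demo.py", "--model_name_or_path", SFT_MERGED, *COMMON,
       "--checkpoint_dir", DPO_ADAPTER),
    # 5. Same model and adapter as step 4, now generating predictions for a dataset.
    py("src/train_bash.py", "--stage", "sft", "--do_predict", "--model_name_or_path", SFT_MERGED,
       *COMMON, "--checkpoint_dir", DPO_ADAPTER, "--dataset", "my_dataset",
       "--output_dir", "predict_output", "--predict_with_generate"),
    # 6. (Optional) Merge the DPO adapter into the SFT-merged model.
    py("src/export_model.py", "--model_name_or_path", SFT_MERGED, *COMMON,
       "--checkpoint_dir", DPO_ADAPTER, "--export_dir", "dpo_export"),
]

if __name__ == "__main__":
    for step, argv in enumerate(workflow, start=1):
        print(f"# step {step}")
        print(shlex.join(argv))
```

Steps 4 and 5 share `--model_name_or_path` and `--checkpoint_dir`, which is the point of step 5: if the interactive check produces text, the batch prediction run with the same model and adapter is expected to as well.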