hiyouga / LLaMA-Factory

Efficiently Fine-Tune 100+ LLMs in WebUI (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0
30.74k stars 3.79k forks source link

Reward Model训练好之后使用src/train_bash.py推理和src/api_demo.py推理,score量纲不一致 #2559

Closed dayL-W closed 6 months ago

dayL-W commented 6 months ago

Reminder

Reproduction

我的修改

为保证Reward Model出来的分数是有意义的,并且是归一化到0-1的,训练过程中替换了自己的Loss函数

loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

        loss += -(torch.log(chosen_trunc_rewards) + torch.log(1 - rejected_trunc_rewards)).mean()

src/train_bash.py 推理效果

训练之后,调用src/train_bash.py做推理,出来的分数达到了我的预期 {"chosen": 0.79, "rejected": 0.01} {"chosen": 0.91, "rejected": 0.06} {"chosen": 0.78, "rejected": 0.05}

推理脚本为 deepspeed --master_port 29501 --num_gpus 2 src/train_bash.py \ --deepspeed ds_config_zero3.json \ --stage rm \ --do_predict \ --model_name_or_path /disk1/wuwen/models/checkpoint/rm_my_data_qwen_7b_lr5e6_freeze_sigmoid \ --dataset my_rm_test_zh \ --dataset_dir /disk1/wuwen/LLaMA-Factory/data_process/data/ \ --template default \ --output_dir path_to_predict_result \ --per_device_eval_batch_size 1 \ --max_samples 1000 \ --fp16

src/api_demo.py推理效果

使用api_demo.py做单条推理,出来的score没有归一化到0-1,和预期不符 { "id": "scoreeval-default", "object": "score.evaluation", "model": "qwen", "scores": [ 1.9638671875, -0.34716796875 ] }

运行脚本为 CUDA_VISIBLE_DEVICES=0 python src/api_demo.py \ --stage rm \ --model_name_or_path /disk1/wuwen/models/checkpoint/rm_my_data_qwen_7b_lr5e6_freeze_sigmoid \ --template default

我的推断

排查后,2个脚本使用的模型确保是一致的,stage也是一致。唯一的区别是

但是trainer.predict的底层也是使用output = self.model(**inputs)进行推理的,没搞清楚为什么量纲不一致

Expected behavior

No response

System Info

No response

Others

No response

dayL-W commented 6 months ago

发现了,自己代码BUg

yaopanyaopan commented 1 month ago

发现了,自己代码BUg

大佬你好,想问一下。src/api_demo.py推理效果的时候,你的数据请求格式是什么样子的?

yaopanyaopan commented 1 month ago

发现了,自己代码BUg

大佬你好,想问一下。src/api_demo.py推理效果的时候,你的数据请求格式是什么样子的?

payload = json.dumps({
    "model": model_path,
    "messages": [prompt],
    "max_length": 0
})messages部分需要的 history 和 instrcution,候选回复 需要做拼接么?