hiyouga / LLaMA-Factory

Unified Efficient Fine-Tuning of 100+ LLMs (ACL 2024)
https://arxiv.org/abs/2403.13372
Apache License 2.0
35.28k stars 4.35k forks source link

DPO 训练 后输出重复问题 #1347

Closed Cloopen-ReLiNK closed 11 months ago

Cloopen-ReLiNK commented 1 year ago

v100 qwen模型 dpo训练后模型输出一直重复,还出各种乱码及其他语种的东西 数据使用的comparison_gpt4和oaast_rm

tmacsyf commented 1 year ago

请问如何解决的?我的也出现了,求教,谢谢

RavidLightricks commented 1 year ago

How do you run prediction?

MissQueen commented 1 year ago

v100 qwen模型 dpo训练后模型输出一直重复,还出各种乱码及其他语种的东西 数据使用的comparison_gpt4和oaast_rm

我用llama2训练也是重复加乱码,有人知道怎么回事吗

lylcst commented 1 year ago

使用hh数据集,qwen-7b出现重复

kyriekevin commented 1 year ago

请教一下有解决的大佬嘛?我用huggingface trl中dpo示例代码跑qwen(8张A100 80G),也都会出现这样的情况,是需要什么特殊配置嘛?还是qwen的dpo就是特别难训?

lylcst commented 1 year ago

我这边试了使用dpo全量训练bloomz-7b也会出现回复重复,尝试使用lora训练,或者加一个ft loss才可以基本解决重复问题。。。。

vip-china commented 11 months ago

我这边试了使用dpo全量训练bloomz-7b也会出现回复重复,尝试使用lora训练,或者加一个ft loss才可以基本解决重复问题。。。。

请问ft loss如何添加,我也面临这个问题,dpo训练后,回答胡言乱语和乱码,谢谢

lylcst commented 11 months ago

我这边试了使用dpo全量训练bloomz-7b也会出现回复重复,尝试使用lora训练,或者加一个ft loss才可以基本解决重复问题。。。。

请问ft loss如何添加,我也面临这个问题,dpo训练后,回答胡言乱语和乱码,谢谢

我是在CustomDPOTrainer这个类中加了一个计算sft loss的方法

    def sft_loss(self, all_logits, labels):
        all_logps = self._get_batch_logps(
            all_logits,
            labels,
            average_log_prob=True
        )
        return -all_logps.mean()

然后将父类DPOTrainer中的get_batch_metrics方法在CustomDPOTrainer中修改重写了一下,在原有的loss基础上加了ft loss

    def get_batch_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)
        with torch.no_grad():
            if self.ref_model is None:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )
       # 计算ft loss
        batch_size = batch['labels'].size(0) // 2
        sft_loss = self.sft_loss(policy_chosen_logits, batch['labels'].split(batch_size, dim=0)[0])

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()

        return losses.mean() + 0.1*sft_loss, metrics
hiyouga commented 11 months ago

@lylcst 如果效果好可以作为一个新的训练模式加到框架里

lylcst commented 11 months ago

我这边实验是取得了一个不错的效果

yssAI commented 11 months ago

你这个就是增加了policy_chosen_logps的权重?

hiyouga commented 11 months ago

solved in b87c74289d523ef88611b376074199ffd03cf103

tibetgao commented 4 months ago

我遇到过一次此类问题,排查后发现是同时打开了 lora 和 adapter,关掉 adapter 只用 lora 之后问题消失了。DPO 的训练看起来需要将 lora 的参数量设置到远小于 sft,否则效果不一定好。

Tramac commented 4 months ago

看到通过添加 sft loss 来解决有个好奇的问题:既然添加 sft loss 可以解决,那是不是直接用 sft 来训练?如何评估性能提升是 dpo 带来的还是 sft 带来的呢?

Tramac commented 4 months ago

我遇到过一次此类问题,排查后发现是同时打开了 lora 和 adapter,关掉 adapter 只用 lora 之后问题消失了。DPO 的训练看起来需要将 lora 的参数量设置到远小于 sft,否则效果不一定好。

有点迷惑,关掉 adapter 只用 lora 是个什么操作呢?