AlibabaResearch / DAMO-ConvAI

DAMO-ConvAI: The official repository which contains the codebase for Alibaba DAMO Conversational AI.
MIT License
1.1k stars 178 forks source link

PRO: SFT loss为什么没有在token上做平均呢 #53

Closed yingjianling closed 1 year ago

yingjianling commented 1 year ago

process_manager.py文件中计算loss的代码如下:

        sum_scores = torch.cat(score_list, dim=1) #[batch, training_stage]
        suffix_mask = torch.cat(suffix_mask_list, dim=1) #[batch, training_stage]
        scores = sum_scores / suffix_mask #[batch, training_stage]
        total_loss = 0
        for time in range(temp_training_stage - 1):
            neg_reward = batch["rewards"][:, time+1:] # [batch, training_stage-time-1]
            pos_reward = batch["rewards"][:, time] # [batch]

            eps = 1e-10
            neg_temperatures = pos_reward.view(-1, 1) - neg_reward # [batch, training_stage-time-1]
            pos_temperature = torch.max(neg_temperatures, dim=1).values # [batch]
            loss = torch.log(eps + torch.exp(scores[:, time] * pos_temperature) + torch.sum(torch.exp(scores[:, time+1:] * neg_temperatures), dim=1)) - scores[:, time] * pos_temperature # [batch]
            loss = torch.mean(loss).to(local_outputs.hidden_states[0].dtype)

            print_loss[time].append(loss.item())
            total_loss += loss

        sft_index = batch["sft_index"].view(batch_size, 1)
        sft_scores = torch.gather(input = sum_scores, dim = 1, index = sft_index).view(batch_size) #[batch]
        sft_loss = torch.mean(-sft_scores).to(local_outputs.hidden_states[0].dtype)
        sft_loss = args.sft_weight * math.pow(temp_training_stage - 1, 2) * sft_loss
        total_loss += sft_loss

其中sft_loss部分gather的对象是sum_scores而不是scores,可以说下原因吗,实测下来即使beta=0.05,sft_loss也要比rank_loss大一个数量级。

F2-Song commented 1 year ago

Hello!

感谢复现和支持PRO!

scores和sum_scores其实都试过,但用scores的BLEU一直不够高,各种尝试后发现,是在我们的设置下取sum最有效😂

yingjianling commented 1 year ago

谢谢答复,另外还有两个问题想问下:

  1. Table2中HH-RLHF alpaca,3和HH-RLHF chatgpt,3的BoN方法是3选1吗
  2. 你们训完的模型会不会偶尔出现相同的response不断repeat的情况 非常感谢
F2-Song commented 1 year ago

客气了~

  1. HH-RLHF{Alpaca,3}和HH-RLHF{ChatGPT,3}的BoN算法会使用三个candidate中最好的candidate来训练,“最好”则是根据我们使用的RM_{train}的打分最高来决定的。

  2. 相同的response不断repeat的情况,印象里PRO还算出现得比较少。可以考虑在生成时通过no_repeat_ngram或者repeat_penalty来缓解一下;但更根本的,个人认为需要在训练时注意让last token预测的next token是EOS,且这一步预测需要纳入到loss计算中。

供参考~

yingjianling commented 1 year ago

第二个问题我们也考虑到的,我们用的是下一轮的开头来替代EOS,起到终止本轮对话的作用,比如”\n### Human:“ ,自然也是加入到loss计算中的。这个方法在我们以前的实验中都可以有效防止这个现象,但是不知道是不是数据集的原因,我们在HH数据集上只做SFT都会出现这个问题。谢谢答复。