GanjinZero / RRHF

[NIPS2023] RRHF & Wombat
780 stars 49 forks source link

算loss的时候求均值的时候是不是可以优化 #51

Open shyoulala opened 6 months ago

shyoulala commented 6 months ago

我看到在sft_loss 的时候直接求了平均,平均的分母是样本label的长度,包括不参与训练的,是否应该采用mask mean 就像: item = -logit_label[max_idx] return -torch.sum(item)/ torch.sum(labels!=-100)。##因为在gather_logits_labels 这一步把-100的prob已经变成0了 而不是-logit_label[max_idx].mean()

image

GanjinZero commented 6 months ago

肯定可以的

IT-five commented 6 months ago

肯定可以的

我想在tensorboard中同时显示rrhf_loss和sft_loss的loss曲线,在哪里添加呀?

IT-five commented 6 months ago

这里的logit_label.sum(-1)是负数,那这里长度就不是惩罚了把,length_penalty如果设为2,那不是得分更高了。

def get_score(self, logit_label, labels):
        mask = (labels != -100).float()
        length = mask.sum(-1) # 当前的seq_length
        scores = logit_label.sum(-1) / (length ** self.args.length_penalty) # shape=torch.size([5])
        return scores
GanjinZero commented 6 months ago

length永远是正的

IT-five commented 6 months ago

length永远是正的

是啊,length是正的,但是logit_label.sum(-1)是经过F.log_softmax()的,所以一定是负数,那负数➗更大的值,不是scores会增大吗?比如-2/1和-2/2

IT-five commented 6 months ago
logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] # (batch * cand) * L * V
logits = F.log_softmax(logits, dim=-1)
logit_label = self.gather_logits_labels(logits, inputs.get("labels"))
scores = self.get_score(logit_label, inputs.get("labels"))