GanjinZero / RRHF

[NIPS2023] RRHF & Wombat

RRHFTrainer.gather_logits_labels label in-place operation error #37

Open asadfgglie opened 11 months ago

asadfgglie commented 11 months ago

The original implementation of RRHFTrainer.gather_logits_labels in train.py is:

mask = (labels != -100).float()
new_logits = logits.clone()  # Create a copy to avoid in-place modification
labels[labels == -100] = 0  # in-place error!
output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
output = output * mask # B * L
return output
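Below is a minimal sketch of a non-mutating variant. It is written as a standalone function instead of the `RRHFTrainer` method, and it is not the repository's actual fix. The key change is that `masked_fill` (without the trailing underscore) returns a new tensor, so the caller's `labels` keep their `-100` entries.

```python
import torch


def gather_logits_labels(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the logit at each label id, zeroed where the label is -100.

    Sketch of a non-mutating variant: `labels` is never written to.
    """
    mask = (labels != -100).float()
    # masked_fill (out-of-place) builds a separate index tensor; -100 -> 0 only in the copy
    safe_labels = labels.masked_fill(labels == -100, 0)
    output = torch.gather(logits, dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
    return output * mask  # B * L


# Usage: the -100 padding in `labels` survives the call
labels = torch.tensor([[5, 2, -100], [1, -100, -100]])
logits = torch.randn(2, 3, 8)
out = gather_logits_labels(logits, labels)
```

Because `torch.gather` does not modify its input, the `logits.clone()` in the original is not needed for correctness. The mutation that matters is the one on `labels`.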

With the original code, get_score at train.py line 274 cannot compute the length of each response correctly:

def get_score(self, logit_label, labels):
    mask = (labels != -100).float()  # all 1.0: after the in-place write, no label is -100 any more
    length = mask.sum(-1)
    scores = logit_label.sum(-1) / (length ** self.args.length_penalty)
    return scores

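Putting `gather_logits_labels` and `get_score` together, the score of response $i$ can be written as follows. The notation is introduced here for illustration: $\ell_{i,t}$ is the gathered label logit, $y_{i,t}$ the label, $L$ the padded length, and $\lambda$ is `self.args.length_penalty`.

```math
s_i = \frac{\sum_{t=1}^{L} \ell_{i,t}\, m_{i,t}}{\left(\sum_{t=1}^{L} m_{i,t}\right)^{\lambda}},
\qquad m_{i,t} = \mathbb{1}\left[y_{i,t} \neq -100\right]
```

The numerator's mask is computed inside `gather_logits_labels` before the overwrite, so it is still correct. Suppose `get_score` then receives the same, already-overwritten tensor, as this issue reports. In that case every $m_{i,t}$ in the denominator equals 1, and the denominator becomes $L^{\lambda}$ for every response.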
GanjinZero commented 11 months ago

I haven't run into this bug before. Did you change the batch size or something else?

asadfgglie commented 11 months ago

I haven't changed the batch size. The bug is in your latest commit, https://github.com/GanjinZero/RRHF/commit/edf17648c115c94be93e10c5fd0a7128aafc95f5.

GanjinZero commented 11 months ago

Which Python version are you using? What is the error?

asadfgglie commented 11 months ago

python==3.9.6 torch==1.13.1+cu117

asadfgglie commented 11 months ago

Nothing is reported when it runs. It is a silent runtime logic error that I found while stepping through the code with a debugger.

asadfgglie commented 11 months ago

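The main effect is on the denominator `length` in get_score. Instead of each response's own length, every response gets the same value: the length of the longest response, i.e. the full padded length.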
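The effect can be reproduced in isolation. The sketch below copies the reported snippet into a hypothetical free function, `buggy_gather_logits_labels`. It also uses a hypothetical helper, `response_lengths`, that computes the same denominator as `get_score` before the `length_penalty` exponent. The toy labels are invented for illustration.

```python
import torch


def buggy_gather_logits_labels(logits, labels):
    # Copy of the snippet reported in this issue (as a free function)
    mask = (labels != -100).float()
    new_logits = logits.clone()
    labels[labels == -100] = 0  # in-place: overwrites the caller's tensor
    output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return output * mask


def response_lengths(labels):
    # Same quantity as `length` in get_score
    return (labels != -100).float().sum(-1)


# Two responses padded to length 4; their real lengths are 2 and 4
labels = torch.tensor([[7, 3, -100, -100], [1, 2, 3, 4]])
logits = torch.randn(2, 4, 10)

before = response_lengths(labels)
buggy_gather_logits_labels(logits, labels)
after = response_lengths(labels)
print(before.tolist(), after.tolist())  # [2.0, 4.0] [4.0, 4.0]
```

After the call, the shorter response is normalized by the padded length 4 instead of its real length 2.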

asadfgglie commented 11 months ago

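The original `labels[labels == -100] = 0` writes into the tensor's original memory. As a result, the values in inputs['labels'] are changed as well, and the get_score computation is wrong.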
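The aliasing can be shown with a plain dict standing in for `inputs`. This assumes, as the comment above describes, that the very same tensor object reaches `gather_logits_labels`. The helper `zero_ignored_in_place` is hypothetical and performs only the reported assignment. Passing a `.clone()` gives the callee its own storage, so the dict's tensor is left intact.

```python
import torch


def zero_ignored_in_place(labels):
    # Same assignment as the reported line: writes into the existing storage
    labels[labels == -100] = 0
    return labels


# Passing the tensor itself: the dict sees the change
shared = {"labels": torch.tensor([[4, -100], [9, 8]])}
zero_ignored_in_place(shared["labels"])
print(shared["labels"].tolist())  # [[4, 0], [9, 8]]

# Passing a clone: only the copy is modified
guarded = {"labels": torch.tensor([[4, -100], [9, 8]])}
zero_ignored_in_place(guarded["labels"].clone())
print(guarded["labels"].tolist())  # [[4, -100], [9, 8]]
```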

GanjinZero commented 11 months ago

Thanks for the feedback.