Open asadfgglie opened 11 months ago
我之前没有遇到这个bug,你改了batch size还是什么吗
我沒有改過batch size 這是你最新的https://github.com/GanjinZero/RRHF/commit/edf17648c115c94be93e10c5fd0a7128aafc95f5 commit的bug
你的python版本?error是什么
python==3.9.6 torch==1.13.1+cu117
執行時不會報error 這是一個runtime error 我開著debuger研究時發現的
主要的效果是會導致再計算get_score
時他的分母length
會統一成最大長度的responses的長度
而不是各個responses的長度
原始的labels[labels == -100] = 0
會改到原始記憶體位址中的值,導致inputs['labels']
的值跟著被更改,使得get_score
計算有誤
感谢反馈
這會導致 train.py line 274的get_score無法正確計算各個responses的length: