lancopku / label-words-are-anchors

Repository for Label Words are Anchors: An Information Flow Perspective for Understanding In-Context Learning
MIT License

Training loss issue #27

Closed. szhsjsu closed this issue 1 month ago

szhsjsu commented 1 month ago

Hello~

I switched the model to Qwen2 and rewrote the attention part accordingly. Running it on the SST-2 data works fine, but the loss always becomes NaN after the first batch, even after lowering the learning rate to a very small value. Debugging shows: RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.

I traced it down to this point: when the backward pass reaches here, attn_weights becomes empty:

    class AttentionAdapter(AttentionAdapterBase):
        def __init__(self, n_demo, n_head, device) -> None:
            super().__init__()
            self.n_demo = n_demo
            self.n_head = n_head
            # one learnable log-scale weight per (head, demonstration) pair
            self.weight = torch.nn.Parameter(
                torch.zeros((n_head, n_demo), requires_grad=True, device=device))
            self.class_poss = None
            self.final_poss = None

        def _forward(self, attn_weights):
            class_poss = self.class_poss
            final_poss = self.final_poss
            weight = self.weight.exp()
            bsz, n_head, seq_len, _ = attn_weights.shape
            assert bsz == 1
            # reweight the attention from the final position onto the label-word positions
            mask_mat = torch.ones((1, n_head, seq_len, seq_len), device=attn_weights.device)
            mask_mat[:, :, final_poss, class_poss] = weight.reshape(1, self.n_head, self.n_demo)
            return attn_weights * mask_mat

Is there a way to fix this?
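
For reference, the RuntimeError above is what PyTorch's autograd anomaly detection reports. A self-contained toy snippet that triggers the same kind of "MulBackward0 returned nan values" failure (shapes and values here are made up purely for illustration, not taken from the actual run):

    import torch

    # Anomaly detection is what produces the "Function 'MulBackward0' returned
    # nan values" trace: it raises as soon as any backward function emits NaN.
    torch.autograd.set_detect_anomaly(True)

    # Toy setup mimicking attn_weights * mask_mat where one operand already
    # carries a NaN; the gradient of the product is then NaN as well.
    attn_weights = torch.rand(1, 2, 4, 4, requires_grad=True)
    mask_mat = torch.ones(1, 2, 4, 4)
    mask_mat[0, 0, 0, 0] = float("nan")

    (attn_weights * mask_mat).sum().backward()
    # -> RuntimeError: Function 'MulBackward0' returned nan values ...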

leanwang326 commented 1 month ago

Could it be that flash_attn is actually being called, so there are no attn_weights?
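
If that is the cause, one quick check is to force the eager attention implementation when loading the model, so that the per-head attention matrix is actually materialized for the adapter to reweight. A minimal sketch, assuming a recent transformers version (the checkpoint name is just an example):

    from transformers import AutoModelForCausalLM

    # attn_implementation="eager" disables the flash-attention / SDPA kernels,
    # so the attention probabilities are computed explicitly and the adapter's
    # multiplication sees (and differentiates through) real attn_weights.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",  # example checkpoint
        attn_implementation="eager",
    )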

szhsjsu commented 1 month ago

> Could it be that flash_attn is actually being called, so there are no attn_weights?

Found the problem: loading the model in int4 caused an overflow in the backward computation~ Thanks for replying so quickly, the reweighting really works!
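
In case it helps others hitting the same thing, the corresponding change is to load the base model in a higher-precision dtype rather than 4-bit quantization; roughly along these lines (checkpoint name and dtype are illustrative, pick whatever your hardware supports):

    import torch
    from transformers import AutoModelForCausalLM

    # Loading in bf16 (or fp16/fp32) instead of int4 keeps the backward pass
    # through the reweighted attention from overflowing.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",       # example checkpoint
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",    # keep attn_weights available for the adapter
    )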