Hi~
I switched the model to Qwen2 and rewrote the attention part. Running it on SST-2 data works fine, but the loss always turns into NaN after the first batch, and lowering the lr to a very small value makes no difference. Debugging surfaced: RuntimeError: Function 'MulBackward0' returned nan values in its 1th output.
Tracing it back, attn_weights has become empty by the time the backward pass reaches this point:
```python
class AttentionAdapter(AttentionAdapterBase):
    def __init__(self, n_demo, n_head, device) -> None:
        super().__init__()
        self.n_demo = n_demo
        self.n_head = n_head
        self.weight = torch.nn.Parameter(
            torch.zeros((n_head, n_demo), requires_grad=True, device=device))
        self.class_poss = None
        self.final_poss = None

    def _forward(self, attn_weights):
        class_poss = self.class_poss
        final_poss = self.final_poss
        weight = self.weight.exp()
        bsz, n_head, seq_len, _ = attn_weights.shape
        assert bsz == 1
        mask_mat = torch.ones(
            (1, n_head, seq_len, seq_len), device=attn_weights.device)
        mask_mat[:, :, final_poss, class_poss] = weight.reshape(
            1, self.n_head, self.n_demo)
        return attn_weights * mask_mat
```
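For what it's worth, a minimal way to pin down where the NaN first appears is PyTorch's anomaly detection plus an explicit finiteness check before the multiply. This is only a debugging sketch; `assert_finite` is a hypothetical helper, not part of the repo:

```python
import torch

# Re-runs the forward trace during backward so the error points at the
# first op whose gradient is NaN -- this is the mechanism behind the
# "Function 'MulBackward0' returned nan values" message above.
torch.autograd.set_detect_anomaly(True)

def assert_finite(name, t):
    """Fail fast if a tensor already carries inf/NaN in the forward pass."""
    if not torch.isfinite(t).all():
        raise RuntimeError(f"{name} contains inf/NaN")

# Example: call assert_finite("attn_weights", attn_weights) at the top of
# AttentionAdapter._forward to tell an upstream overflow apart from a
# problem in the reweighting itself.
```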
Is there any way to fix this?
Could it be that flash_attn is actually being invoked, so attn_weights never gets materialized?
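If that is the cause, a minimal sketch of a workaround (assuming a recent transformers version that accepts the `attn_implementation` argument; the model id is illustrative) is to force the eager attention path when loading Qwen2, so the full per-head attention matrix exists for the adapter to hook into:

```python
import torch
from transformers import AutoModelForCausalLM

# Force the eager (non-flash, non-SDPA) implementation so that the
# (bsz, n_head, seq_len, seq_len) attn_weights tensor is materialized
# and the adapter's mask multiplication has something to act on.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",        # example model id
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
```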
Found the problem: loading the model in int4 caused an overflow in the backward computation~ Thanks for replying so quickly, the reweighting really works!
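For anyone hitting the same symptom: it matches low-precision overflow in the backward pass. A minimal sketch of the fix, assuming the standard transformers loading path (model id again illustrative), is to drop 4-bit loading for the trainable path and use bf16 instead:

```python
import torch
from transformers import AutoModelForCausalLM

# int4-quantized weights leave too little numeric headroom for the
# gradients flowing through attn_weights * mask_mat, so the backward
# pass overflows to inf/NaN. bf16 keeps a much wider dynamic range.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",        # example model id
    torch_dtype=torch.bfloat16,      # instead of load_in_4bit=True
    attn_implementation="eager",
)
```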