Open nkkbr opened 2 months ago
In `RoPEMaskedAttentionHead`, the current code is:

```python
if return_attn_weights:
    attn_mask = torch.tril(torch.ones((m, m)), diagonal=0)
    attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d) + attn_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    return activations, attn_weights
return activations
```
Adding a 0/1 lower-triangular mask does not actually block attention to future positions: after softmax, the masked entries still get nonzero probability. The mask should be 0 where attention is allowed and `-inf` where it is not, so I think it should be:

```python
if return_attn_weights:
    attn_mask = torch.tril(torch.ones((m, m)), diagonal=0)
    attn_mask = torch.where(attn_mask == 1, torch.tensor(0.0), torch.tensor(float('-inf')))
    attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d) + attn_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    return activations, attn_weights
return activations
```
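A minimal standalone sketch of the difference (using `m = 4` and all-zero pre-softmax scores purely for illustration, not the repo's actual tensors):

```python
import torch
import torch.nn.functional as F

m = 4
scores = torch.zeros(m, m)  # stand-in for q @ k^T / sqrt(d)

# Current mask: 1 on allowed positions, 0 elsewhere. Future positions
# still receive probability mass after softmax, so causality is broken.
bad_mask = torch.tril(torch.ones(m, m), diagonal=0)
bad_weights = F.softmax(scores + bad_mask, dim=-1)

# Proposed mask: 0 on allowed positions, -inf elsewhere. Masked entries
# become exactly zero after softmax.
good_mask = torch.where(bad_mask == 1, torch.tensor(0.0), torch.tensor(float('-inf')))
good_weights = F.softmax(scores + good_mask, dim=-1)

print(bad_weights[0])   # first token still attends to future positions
print(good_weights[0])  # first token attends only to itself
```

With the 0/1 mask, row 0 comes out roughly `[0.48, 0.17, 0.17, 0.17]`; with the `-inf` mask it is `[1, 0, 0, 0]`, as a causal mask requires.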