bkitano / llama-from-scratch

Llama from scratch, or How to implement a paper without crying
https://blog.briankitano.com/llama-from-scratch/

RoPEMaskedAttentionHead #6

Status: Open · opened by nkkbr 2 months ago

nkkbr commented 2 months ago

In RoPEMaskedAttentionHead, the attention weights are currently computed as:

        if return_attn_weights:
            attn_mask = torch.tril(torch.ones((m,m)), diagonal=0)
            attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1,2)) / np.sqrt(d) + attn_mask
            attn_weights = F.softmax(attn_weights, dim=-1)
            return activations, attn_weights
        return activations

Here, torch.tril(torch.ones((m, m)), diagonal=0) adds +1 to the allowed (past and current) positions and 0 to the future positions, so the future scores are not suppressed at all and softmax still assigns them nonzero weight. The future positions need to be -inf before the softmax so their weights come out as exactly zero. I think it should be:

        if return_attn_weights:
            # Causal mask: start with 1s on allowed positions, 0s on future ones
            attn_mask = torch.tril(torch.ones((m, m)), diagonal=0)
            # Convert to an additive mask: 0 where attention is allowed, -inf on future positions
            attn_mask = torch.where(attn_mask == 1, torch.tensor(0.0), torch.tensor(float('-inf')))
            # Scaled dot-product scores; the -inf entries become 0 after softmax
            attn_weights = torch.bmm(q_rotated, k_rotated.transpose(1, 2)) / np.sqrt(d) + attn_mask
            attn_weights = F.softmax(attn_weights, dim=-1)
            return activations, attn_weights
        return activations
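
To illustrate the difference outside the class, here is a small self-contained sketch (not from the repo; the m = 4 size and random scores are made up for the example). With the tril-of-ones additive mask, softmax still puts nonzero weight on future positions; with the -inf mask, those weights are exactly zero:

    import torch
    import torch.nn.functional as F

    m = 4
    scores = torch.randn(1, m, m)  # toy attention scores, batch of 1

    # Mask as in the current code: 1 on allowed positions, 0 on future ones
    ones_mask = torch.tril(torch.ones((m, m)), diagonal=0)
    weights_ones = F.softmax(scores + ones_mask, dim=-1)

    # Proposed mask: 0 on allowed positions, -inf on future ones
    inf_mask = torch.where(ones_mask == 1, torch.tensor(0.0), torch.tensor(float('-inf')))
    weights_inf = F.softmax(scores + inf_mask, dim=-1)

    # Total weight leaked to future positions (strict upper triangle)
    print(weights_ones[0].triu(diagonal=1).sum())  # > 0: future tokens still attended to
    print(weights_inf[0].triu(diagonal=1).sum())   # 0.0: properly masked

As a side note, another common way to write the same masking is scores.masked_fill(torch.triu(torch.ones(m, m), diagonal=1).bool(), float('-inf')), which avoids building a separate additive mask; either form gives the same attention weights.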