lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based Transformer, in PyTorch

Inverse of renormalization matrix being used? #13

Closed · sidnarayanan closed this 3 years ago

sidnarayanan commented 3 years ago

Hi! Thanks for the speedy reference implementation in PyTorch.

I noticed something odd: I think you're multiplying by D (from Eq. 4 of the paper), not D^-1, in linear_attention:

def linear_attention(q, k, v):
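    # NOTE: no reciprocal is taken below, so this tensor holds D, not D^-1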
    D_inv = torch.einsum('...nd,...d->...n', q, k.sum(dim = -2))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out

The tensor named D_inv is actually D. Unless I'm missing something obvious, which hopefully I am.
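
For reference, Eq. 4 of the paper gives out = D^-1 (Q' ((K')^T V)) with D = diag(Q' ((K')^T 1_L)), so the renormalizer should divide, not multiply. Here's a minimal sketch of what I'd expect instead, assuming the intent is simply to take the elementwise reciprocal (otherwise identical to the snippet above):

    import torch

    def linear_attention(q, k, v):
        # D^-1: reciprocal of the per-position renormalizer q_n . (sum_m k_m)
        D_inv = 1. / torch.einsum('...nd,...d->...n', q, k.sum(dim = -2))
        # unnormalized context: sum_n k_n v_n^T, shape (..., d, e)
        context = torch.einsum('...nd,...ne->...de', k, v)
        # apply the context to each query and renormalize by D^-1
        out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
        return out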

[tagging @ncilfone because we discussed this recently]

lucidrains commented 3 years ago

🐛 🐝 🪲 🙏