Closed · sidnarayanan closed this issue 3 years ago
Hi! Thanks for the speedy reference implementation in Pytorch.
I noticed something odd: I think you're multiplying by D (in eq. 4), not D^-1, in `linear_attention`:
```python
def linear_attention(q, k, v):
    D_inv = torch.einsum('...nd,...d->...n', q, k.sum(dim = -2))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
```
The tensor named `D_inv` is actually D; the reciprocal is never taken, so the output ends up multiplied by the normalizer instead of divided by it. Unless I'm missing something obvious, which hopefully I am.
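If I'm reading eq. 4 correctly, the fix seems to be as simple as actually inverting the normalizer; a minimal sketch with the same einsums:

```python
import torch

def linear_attention(q, k, v):
    # D from eq. 4: per-query normalizer q · (sum over keys)
    D = torch.einsum('...nd,...d->...n', q, k.sum(dim = -2))
    D_inv = 1. / D  # actually take the reciprocal, so each row is divided by D
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
```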
[tagging @ncilfone because we discussed this recently]
:bug: :bee: :beetle: :pray: