lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

Applying decoder input mask? #51

Closed · maxmax1992 closed this 3 years ago

maxmax1992 commented 3 years ago

Hi, I'm trying to implement the basic transformer architecture from "Attention Is All You Need", replacing MultiHeadAttention with performer_pytorch.SelfAttention. However, the mask expected for the decoder input is apparently not of shape n × n. I've tried different setups, but with no success. Any tips or ideas? I've only glanced through the paper.

lucidrains commented 3 years ago

@maxmax1992 Hi Maxim! You do not have to worry about passing in the N×N triangular mask for the decoder. Simply set the causal keyword argument to True and it will all be taken care of!
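
For reference, a minimal sketch of what that looks like, following the SelfAttention usage shown in this repository's README (the dim and heads values here are illustrative, not required):

```python
import torch
from performer_pytorch import SelfAttention

# causal = True applies the autoregressive (lower-triangular) masking
# internally, so no explicit n x n mask needs to be passed in
attn = SelfAttention(
    dim = 512,       # model dimension (illustrative)
    heads = 8,       # number of attention heads (illustrative)
    causal = True    # decoder-style causal masking
)

x = torch.randn(1, 1024, 512)  # (batch, sequence length, dim)
out = attn(x)                  # (1, 1024, 512); each position attends only to earlier positions
```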

maxmax1992 commented 3 years ago

Closing this, as the solution is to pass causal=True to the SelfAttention class.