pmixer / SASRec.pytorch

PyTorch (1.6+) implementation of https://github.com/kang205/SASRec
Apache License 2.0

Why is the attention_mask's shape (tl, tl)? #30

Open · rabbicat30 opened this issue 1 year ago

rabbicat30 commented 1 year ago
tl = seqs.shape[1]  # time dim len for enforce causality

attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))

I can't understand why the attention_mask has this shape. Can you give me an explanation or some references? I would be very grateful for your help!

seanswyi commented 1 year ago

You should look at the original Transformer paper and some blog posts (e.g., The Illustrated Transformer is great) for more background. The reason is that in self-attention we perform attention of a sequence with itself: each of the tl query positions scores each of the tl key positions, so the mask needs one row per query and one column per key, hence the square (tl, tl) shape.
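
To make the shape concrete, here is a small self-contained sketch (tl = 4, embed_dim = 8, and num_heads = 2 are made-up values for illustration; the mask is built the same way as the line quoted above, with the device argument dropped). Row i of the mask says which of the tl key positions query position i may attend to, and the True entries above the diagonal are the "future" positions that get blocked:

import torch

tl = 4  # illustrative sequence length

# Same construction as the quoted line: torch.tril keeps the lower triangle
# (including the diagonal), and negating it marks every future position with True,
# which nn.MultiheadAttention interprets as "do not attend".
attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool))
print(attention_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# Passing the same tensor as query, key and value is what makes this *self*-attention,
# and also what makes the (query positions) x (key positions) mask square.
mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(tl, 1, 8)  # (seq_len, batch, embed_dim)
out, attn_weights = mha(x, x, x, attn_mask=attention_mask)
print(attn_weights[0])  # entries above the diagonal are exactly 0 after the masked softmax

Note that attn_weights (averaged over heads) has shape (batch, tl, tl), i.e. it is square again for the same reason: one row per query position, one column per key position.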

rabbicat30 commented 1 year ago

I understand it now. Thanks very much!
