lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
MIT License

mask error #1

Closed · keyunluo closed this issue 2 years ago

keyunluo commented 2 years ago
```python
import torch

x = torch.randint(0, 20000, (1, 1024))  # random token ids, shape (batch, seq)
mask = x.ne(0)                           # padding mask: True for non-padding tokens
logits = model(x, mask = mask)
```

This raises:

```
RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2
```

lucidrains commented 2 years ago

@keyunluo oops! should be fixed here https://github.com/lucidrains/FLASH-pytorch/commit/8e0d2fd7925c0de9703d666ea2cc004327f6e544
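
For later readers, a minimal sketch of the reporter's setup with the fix applied, assuming the `FLASHTransformer` wrapper and the `num_tokens` / `dim` / `depth` arguments shown in the repository README (other constructor arguments are left at their defaults):

```python
import torch
from flash_pytorch import FLASHTransformer  # wrapper name taken from the repository README

# constructor arguments follow the README example and may need adjusting
model = FLASHTransformer(
    num_tokens = 20000,  # vocabulary size
    dim = 512,
    depth = 12
)

x = torch.randint(0, 20000, (1, 1024))  # (batch, seq) token ids
mask = x.ne(0)                          # (batch, seq) boolean padding mask
logits = model(x, mask = mask)          # expected shape (1, 1024, 20000) once the fix is applied
```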

marsggbo commented 1 year ago

In the vanilla multi-head transformer, the mask is an upper-triangular matrix of shape [bs, n_head, seq, seq]. But here, the shape of mask is [bs, seq]. Is there an explanation for this? How do we set the mask for mixed chunk attention so that it behaves like the upper-triangular mask?

lucidrains commented 1 year ago

@marsggbo just set causal = True and everything is taken care of

lucidrains commented 1 year ago

@marsggbo the mask argument is for non-causal, variable-length sequences
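
Putting the two replies together, a hedged sketch of the two modes, again assuming the README-style constructor: set `causal = True` for autoregressive models, and pass a `(batch, seq)` boolean `mask` at call time for bidirectional models with padded batches.

```python
from flash_pytorch import FLASHTransformer  # assumed import path, per the README

# autoregressive: the upper-triangular causal mask is built internally,
# so nothing needs to be passed to forward()
causal_model = FLASHTransformer(num_tokens = 20000, dim = 512, depth = 12, causal = True)

# bidirectional: supply a (batch, seq) boolean mask marking real (non-padding) tokens
bidirectional_model = FLASHTransformer(num_tokens = 20000, dim = 512, depth = 12)
```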

marsggbo commented 1 year ago

@lucidrains thanks for your quick reply. I double-checked the code and understand that the causal_mask is created when causal==True, so everything is taken care of.

However, I'm curious about the two parts below. https://github.com/lucidrains/FLASH-pytorch/blob/473ab6209fae8711f1678e8a0a1e112262e4600a/flash_pytorch/flash_pytorch.py#L317-L319

https://github.com/lucidrains/FLASH-pytorch/blob/473ab6209fae8711f1678e8a0a1e112262e4600a/flash_pytorch/flash_pytorch.py#L352-L353

In summary, my two questions are: since mask=None by default, what is this variable for, and what is the difference between mask and causal_mask?

lucidrains commented 1 year ago

mask is for bidirectional attention (like BERT)

you often have sequences of variable length, so you would use the mask to attend only to non-padding tokens
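
To make the distinction concrete, here is an illustration of the two masks in plain PyTorch (the general idea, not necessarily the exact construction at the source lines linked above): the `mask` argument flags real tokens per sequence, while the causal mask is derived internally from the sequence length when `causal = True`.

```python
import torch

seq_len, pad_id = 6, 0
x = torch.tensor([[5, 3, 8, 2, 0, 0]])   # one padded sequence of token ids

# padding mask passed by the user: True where there is a real token, shape (bs, seq)
mask = x.ne(pad_id)                       # tensor([[True, True, True, True, False, False]])

# causal mask built internally when causal=True: shape (seq, seq),
# True entries above the diagonal mark the "future" positions to be masked out
causal_mask = torch.ones((seq_len, seq_len), dtype = torch.bool).triu(1)
```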