lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

einsum operation in Linear Attention Part #2

Closed: ShomyLiu closed this issue 2 years ago

ShomyLiu commented 2 years ago

Hi, thanks a lot for FLASH-pytorch, it helps a lot. I found that the linear attention part differs a bit from the paper: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

Here lin_kv is three-dimensional (b d e), while the code in the paper is

lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)

where lin_kv is four-dimensional (b g k e). It seems that the two formulations are not equivalent.
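
To make the shape difference concrete, here is a small self-contained PyTorch sketch; the toy sizes and the all-ones mask are placeholders of mine, not the repo's or the paper's actual configuration (I also write d for the paper's k index):

import torch
from torch import einsum

b, g, n, d, e = 2, 4, 8, 16, 16           # batch, groups, group size, key dim, value dim
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v     = torch.randn(b, g, n, e)

# repo version: sums over both the group axis g and the position axis n,
# producing one global kv summary of shape (b, d, e) shared by every group
lin_kv   = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
out_repo = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

# paper-style version: keep a per-group summary of shape (b, g, d, e),
# then mix groups through an explicit (g, g) mask, so lin_kv stays four-dimensional
mask      = torch.ones(g, g)              # placeholder non-causal mask (assumption)
lin_kv_g  = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n
lin_kv_m  = einsum('b h d e, g h -> b g d e', lin_kv_g, mask)
out_paper = einsum('b g n d, b g d e -> b g n e', lin_q, lin_kv_m)

# with an all-ones mask the two coincide; they differ once the mask is
# non-trivial, e.g. lower-triangular over groups for causal models
print(torch.allclose(out_repo, out_paper, atol=1e-5))  # True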

Looking forward to your reply. Best,

lucidrains commented 2 years ago

@ShomyLiu thank you for catching this! https://github.com/lucidrains/FLASH-pytorch/commit/0cb9473dc1f839ce1c61fd0dacb7fcc161ea3b1e
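
For anyone following along, the idea is to keep the group axis in lin_kv and mix groups explicitly. One way to make the mixing causal over groups without materializing a (g, g) mask is an exclusive cumulative sum over the group axis, which is the same as applying a strictly lower-triangular group mask. A standalone sketch of the idea (toy sizes; see the commit for the actual change):

import torch
import torch.nn.functional as F
from torch import einsum

b, g, n, d, e = 2, 4, 8, 16, 16
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v     = torch.randn(b, g, n, e)

# per-group kv summaries, shape (b, g, d, e)
lin_kv_g = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n

# mask formulation: group i only sees groups j < i
mask      = torch.tril(torch.ones(g, g), diagonal=-1)
kv_masked = einsum('b h d e, g h -> b g d e', lin_kv_g, mask)

# cumsum formulation: exclusive prefix sum over the group axis
kv_cumsum = lin_kv_g.cumsum(dim=1)
kv_cumsum = F.pad(kv_cumsum, (0, 0, 0, 0, 1, 0))[:, :-1]   # shift right by one group

print(torch.allclose(kv_masked, kv_cumsum, atol=1e-5))      # True

lin_out = einsum('b g n d, b g d e -> b g n e', lin_q, kv_cumsum)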

wangleiofficial commented 2 years ago

Reading the expressions and formulas for this part, the reduction should be over the group dimension.

[screenshots of the relevant equations from the paper]

ShomyLiu commented 2 years ago

Hi, it is true that the formulas include a reduction over all groups. However, in Code 8 (Pseudocode for FLASH) on the final page, there is no reduction over the groups, so maybe both are OK. (In my opinion, if there were a sum reduction over all groups, wouldn't the attention results become much larger than the quad_part?)
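
To make the two readings concrete (toy shapes of mine, not a claim about what Code 8 intends):

import torch
from torch import einsum

b, g, n, d, e = 2, 4, 8, 16, 16
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v     = torch.randn(b, g, n, e)

lin_kv_g = einsum('b g n d, b g n e -> b g d e', lin_k, v) / n

# reading 1: no reduction over groups -- each group uses only its own summary
out_per_group = einsum('b g n d, b g d e -> b g n e', lin_q, lin_kv_g)

# reading 2: sum the summaries over all groups -- every group sees every
# group's content, so the linear term accumulates contributions from all g groups
out_summed = einsum('b g n d, b h d e -> b g n e', lin_q, lin_kv_g)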

lucidrains commented 2 years ago

oh! ok, no problem https://github.com/lucidrains/FLASH-pytorch/commit/0a4e3e20478e0f2c545a8cb43e51fb5c2e6b3b42

ShomyLiu commented 2 years ago

Nice, thanks again for your work~