ShomyLiu closed this issue 2 years ago.
@ShomyLiu thank you for catching this! https://github.com/lucidrains/FLASH-pytorch/commit/0cb9473dc1f839ce1c61fd0dacb7fcc161ea3b1e
Reading that part of the paper's expressions and formulas, the reduction should be over the group dimension.
Hi,
The formulas do indeed show a reduction over all groups. However, in Code 8: Pseudocode for FLASH on the final page of the paper, there is no reduction over groups, so maybe both are OK.
(In my opinion, if there were a sum reduction over all groups, the linear attention results would be considerably larger in magnitude than the `quad_part`; see the sketch below.)
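To make the magnitude concern above concrete, here is a minimal sketch (illustrative shapes and variable names, not code from the repo or the paper) comparing the linear-attention output when each group uses its own kv summary versus the group-summed one:

```python
import torch

# Illustrative shapes: b = batch, g = groups, n = tokens per group,
# d = query/key dim, e = value dim. All names here are hypothetical.
b, g, n, d, e = 2, 4, 8, 16, 32
lin_q = torch.randn(b, g, n, d)
lin_k = torch.randn(b, g, n, d)
v = torch.randn(b, g, n, e)

# Per-group kv summary: reduce over tokens only -> (b, g, d, e)
per_group_kv = torch.einsum('bgnd,bgne->bgde', lin_k, v)
# Group-summed kv summary: additionally reduce over groups -> (b, d, e)
summed_kv = per_group_kv.sum(dim=1)

out_per_group = torch.einsum('bgnd,bgde->bgne', lin_q, per_group_kv)
out_summed = torch.einsum('bgnd,bde->bgne', lin_q, summed_kv)

# For independent random inputs the group-summed variant is roughly
# sqrt(g) times larger in norm (and can approach g times for highly
# correlated inputs), which is the magnitude concern raised above.
print(out_per_group.norm().item(), out_summed.norm().item())
```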
Nice, thanks again for your work~
Hi,
Thanks a lot for your FLASH-pytorch, which helps a lot. I found some differences from the paper in the linear attention part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343
In the code, `lin_kv` is three-dimensional (`bde`), while in the paper's pseudocode `lin_kv` is four-dimensional (`bgke`). It seems that the two ways are not equivalent; see the sketch below. Looking forward to your reply.
Best,
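For concreteness, here is a minimal sketch (assumed shapes and dimension names, not the repo's actual code) of the two einsum formulations described above and how they relate:

```python
import torch

# Illustrative shapes: b = batch, g = groups, n = tokens per group,
# d = key dim (written k in the paper's notation), e = value dim.
b, g, n, d, e = 2, 4, 8, 16, 32
lin_k = torch.randn(b, g, n, d)
v = torch.randn(b, g, n, e)

# Repo-style: reduce over tokens AND groups -> three-dim (b, d, e)
lin_kv_bde = torch.einsum('bgnd,bgne->bde', lin_k, v)

# Paper-pseudocode-style: reduce over tokens only -> four-dim (b, g, d, e)
lin_kv_bgde = torch.einsum('bgnd,bgne->bgde', lin_k, v)

# Summing the four-dim version over the group dimension recovers the
# three-dim one, so the variants differ only in whether group statistics
# are mixed before being combined with the queries.
assert torch.allclose(lin_kv_bgde.sum(dim=1), lin_kv_bde, atol=1e-4)
```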