Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Error in Algorithm 1 of Flash Attention 2 paper #991

Open mbchang opened 4 months ago

mbchang commented 4 months ago

On line 10 of Algorithm 1 (the FlashAttention-2 forward pass) of the FlashAttention-2 paper, the previous output block is rescaled by $\text{diag}\left(e^{m_i^{(j-1)} - m_i^{(j)}}\right)^{-1}$.

However, this factor should not have the $^{-1}$: the term should just be $\text{diag}\left(e^{m_i^{(j-1)} - m_i^{(j)}}\right)\mathbf{O}_i^{(j-1)}$.

Similarly, the online softmax trick example at the top of page 6 should also not have the $^{-1}$ before $\tilde{\mathbf{O}}^{(1)}$.
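
As a quick sanity check (a sketch using notation along the lines of the paper's two-block online softmax example, with $m^{(1)}, m^{(2)}$ the running row maxima), substituting $\tilde{\mathbf{O}}^{(1)} = e^{\mathbf{S}^{(1)} - m^{(1)}}\mathbf{V}^{(1)}$ shows that the non-inverted factor produces the correct accumulator:

$$
\mathrm{diag}\left(e^{m^{(1)} - m^{(2)}}\right)\tilde{\mathbf{O}}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}}\mathbf{V}^{(2)}
= e^{\mathbf{S}^{(1)} - m^{(2)}}\mathbf{V}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}}\mathbf{V}^{(2)},
$$

which becomes exact softmax attention once divided by the final row sum $\ell^{(2)}$. With the $^{-1}$, the first term would instead be $e^{\mathbf{S}^{(1)} - 2m^{(1)} + m^{(2)}}\mathbf{V}^{(1)}$, which is wrong.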

Lastly, there is a typo: the line shown in the attached screenshot should be deleted from the online softmax trick derivation on page 6. It appears to have been copied from the online softmax trick derivation on page 4 for FlashAttention-1 and was not removed.
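
For concreteness, here is a minimal NumPy sketch (illustrative only, not the repo's kernel code; all names are made up) that checks the two-block recurrence with the corrected, non-inverted rescaling factor against ordinary softmax attention for a single query row:

```python
# Minimal NumPy check: the two-block online-softmax recurrence with the
# NON-inverted rescaling diag(exp(m^(1) - m^(2))) reproduces exact attention.
import numpy as np

rng = np.random.default_rng(0)
d, n1, n2 = 4, 3, 5                      # head dim, block sizes
q = rng.standard_normal(d)               # a single query row
K = rng.standard_normal((n1 + n2, d))
V = rng.standard_normal((n1 + n2, d))

# Reference: ordinary softmax attention over the full key/value set.
s = q @ K.T
p = np.exp(s - s.max())
ref = (p / p.sum()) @ V

# Online (two-block) computation with the corrected factor.
s1, s2 = s[:n1], s[n1:]
m1 = s1.max()
o_tilde = np.exp(s1 - m1) @ V[:n1]       # unnormalized block-1 output
l = np.exp(s1 - m1).sum()

m2 = max(m1, s2.max())
scale = np.exp(m1 - m2)                  # corrected rescaling (no inverse)
o_tilde = scale * o_tilde + np.exp(s2 - m2) @ V[n1:]
l = scale * l + np.exp(s2 - m2).sum()

out = o_tilde / l                        # final normalization by the row sum
print(np.allclose(out, ref))             # True with the corrected rescaling
```

Swapping `scale = np.exp(m1 - m2)` for `np.exp(m2 - m1)` (the inverted factor as printed) makes the check fail.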

xyg-coder commented 1 month ago

I was looking at the Triton implementation and it matches your observation. Thanks for pointing this out.