XuezheMax / megalodon

Reference implementation of Megalodon 7B model
MIT License

Flash Attention V2 vs Megalodon Swift Attention #4

Closed: timmytwoteeth closed this issue 2 months ago

timmytwoteeth commented 2 months ago

Hi all,

Thank you for the great work.

Is there any comparison between the performance of Flash Attention 2 vs Megalodon Swift Attention?

It is not apparent to me why one would choose Swift Attention over Flash Attention 2 or PyTorch SDPA with the Flash Attention 2 backend.

violet-zct commented 2 months ago

Hi,

The Swift Attention in this repo provides a fused implementation of the softmax, causal mask, and dropout operations to reduce IO cost. It supports different dimension sizes for Q, K, and V. Since we assume the chunk-wise attention setting, where the chunk size is not large (4096 in our experiments), we do not use the block-wise attention computation from Flash Attention, but standard matrix multiplication within each chunk.
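For illustration, here is a minimal PyTorch sketch of chunk-wise causal attention computed with standard matrix multiplication inside each chunk, where Q/K share one head dimension and V uses another. This is an assumption-laden reference in plain PyTorch, not the fused CUDA kernel from this repo; the function name, argument names, and shapes below are illustrative only.

```python
import torch
import torch.nn.functional as F

def chunked_causal_attention(q, k, v, chunk_size=4096):
    # Hypothetical sketch, not the repo's Swift Attention kernel.
    # q, k: (batch, heads, seq_len, d_qk); v: (batch, heads, seq_len, d_v)
    # Each chunk attends only within itself (chunk-wise attention),
    # so the per-chunk score matrix stays small and plain matmul suffices.
    b, h, n, d_qk = q.shape
    d_v = v.shape[-1]
    out = torch.empty(b, h, n, d_v, device=q.device, dtype=q.dtype)
    scale = d_qk ** -0.5
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        qc = q[..., start:end, :]
        kc = k[..., start:end, :]
        vc = v[..., start:end, :]
        # Standard (non-block-wise) attention within the chunk.
        scores = torch.matmul(qc, kc.transpose(-1, -2)) * scale
        causal = torch.ones(end - start, end - start,
                            device=q.device, dtype=torch.bool).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out[..., start:end, :] = torch.matmul(attn, vc)
    return out
```

In the actual kernel, the softmax, causal mask, and dropout are fused into a single pass over the scores, which is where the IO savings come from; the sketch above only shows the chunking and the differing Q/K vs. V dimensions.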

In summary, if your Q, K, V sequence length is not long and/or you want more flexible QKV shapes, Swift Attention is faster than Flash Attention. We also support fused dropout with the DropKey pattern; a sketch of that pattern follows below. Please see Appendix B of the paper for more details.
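For reference, a minimal sketch of a DropKey-style dropout: random key positions are masked in the attention logits before the softmax, instead of zeroing attention weights after it. This is only an assumed plain-PyTorch rendering of the pattern, not the fused implementation in this repo, and the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def dropkey_softmax(scores, dropout_p, training=True):
    # Hypothetical sketch of DropKey-style dropout.
    # scores: (batch, heads, q_len, k_len) raw attention logits
    if training and dropout_p > 0.0:
        # Drop keys by setting their logits to -inf before softmax,
        # so the remaining weights renormalize over the kept keys.
        drop_mask = torch.rand_like(scores) < dropout_p
        scores = scores.masked_fill(drop_mask, float("-inf"))
    return F.softmax(scores, dim=-1)
```

Because the surviving keys are renormalized by the softmax, no rescaling of the attention weights is needed, unlike standard post-softmax attention dropout. (In a real kernel one would also guard against a row where every key is dropped, which would otherwise produce NaNs.)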

Thanks.

timmytwoteeth commented 2 months ago

Thank you for the input!

The comment on sequence length is helpful.