FlagOpen / FlagAttention

A collection of memory-efficient attention operators implemented in the Triton language.

support grouped query attention (MQA & GQA) for flash_attn #22

Closed iclementine closed 5 months ago

iclementine commented 5 months ago

Support grouped query attention (GQA, with MQA as the single-KV-head special case) for flash_attn. Related kernels: fwd, bwd, split_kv, total_attention.
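As a reminder of what the kernels need to compute, here is a minimal PyTorch reference sketch of grouped-query attention. The tensor layout and the "consecutive query heads share one KV head" convention are assumptions for illustration, not the kernels' actual API.

```python
import math
import torch


def gqa_reference(q, k, v):
    """Reference grouped-query attention (a sketch, not the Triton kernels).

    q:    (batch, num_q_heads, seqlen_q, head_dim)
    k, v: (batch, num_kv_heads, seqlen_k, head_dim)
    num_q_heads must be a multiple of num_kv_heads; MQA is num_kv_heads == 1.
    """
    b, hq, m, d = q.shape
    _, hk, n, _ = k.shape
    assert hq % hk == 0, "query heads must divide evenly into KV heads"
    group = hq // hk  # query heads per shared KV head

    # Expand KV heads so each query head attends to its shared KV head.
    # repeat_interleave puts `group` consecutive query heads on one KV head.
    k = k.repeat_interleave(group, dim=1)  # (b, hq, n, d)
    v = v.repeat_interleave(group, dim=1)  # (b, hq, n, d)

    scores = q @ k.transpose(-1, -2) / math.sqrt(d)
    p = scores.softmax(dim=-1)
    return p @ v  # (b, hq, m, d)
```

In a fused kernel one would not materialize the expanded K/V; instead each program instance working on query head `h` loads the KV head it maps to.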

The MQA paper

Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv, November 5, 2019. https://doi.org/10.48550/arXiv.1911.02150.

The GQA paper

Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” arXiv, December 23, 2023. https://doi.org/10.48550/arXiv.2305.13245.

Mind the layout of the heads in the query, i.e., how query heads are grouped onto the shared key/value heads.
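The head-index mapping below matches the "grouped" layout assumed in the sketch above: consecutive query heads share one KV head. An interleaved layout would use `q_head % num_kv_heads` instead; which convention applies depends on how a checkpoint packs its heads, so this is an assumption to verify against the model being loaded.

```python
def kv_head_index(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query-head index to the KV head it shares (grouped layout)."""
    group = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group
```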