Closed: iclementine closed this 5 months ago
Support grouped query attention (GQA) for flash_attn (related kernels: fwd, bwd, split_kv, total_attention).
The MQA paper
Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv, November 5, 2019. https://doi.org/10.48550/arXiv.1911.02150.
The GQA paper
Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” arXiv, December 23, 2023. https://doi.org/10.48550/arXiv.2305.13245.
Mind the layout of the heads in the query: with GQA the number of query heads is a multiple of the number of key/value heads, and each group of query heads shares a single key/value head, so the kernels must map a query head index to its corresponding key/value head index.
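A minimal pure-PyTorch reference sketch of that head mapping (not the kernel interface of this PR): it assumes the common layout `[batch, num_heads, seq_len, head_dim]` and the usual convention that consecutive query heads share one key/value head; the function name and shapes here are illustrative only.

```python
# Reference GQA attention, assuming [batch, num_heads, seq_len, head_dim] layout.
# Illustrative sketch only; not the flash_attn kernel API.
import math
import torch


def gqa_reference(q, k, v, causal=False):
    # q: [B, Hq, M, D]; k, v: [B, Hkv, N, D] with Hq a multiple of Hkv.
    B, Hq, M, D = q.shape
    Hkv = k.shape[1]
    assert Hq % Hkv == 0, "query heads must be a multiple of key/value heads"
    group = Hq // Hkv

    # Each consecutive group of `group` query heads shares one kv head,
    # so expand k/v along the head dimension to match q's head count.
    k = k.repeat_interleave(group, dim=1)  # [B, Hq, N, D]
    v = v.repeat_interleave(group, dim=1)

    scores = q @ k.transpose(-1, -2) / math.sqrt(D)  # [B, Hq, M, N]
    if causal:
        mask = torch.ones(M, k.shape[2], dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v  # [B, Hq, M, D]


if __name__ == "__main__":
    B, Hq, Hkv, M, N, D = 2, 8, 2, 16, 16, 64
    q = torch.randn(B, Hq, M, D)
    k = torch.randn(B, Hkv, N, D)
    v = torch.randn(B, Hkv, N, D)
    out = gqa_reference(q, k, v, causal=True)
    print(out.shape)  # torch.Size([2, 8, 16, 64])
```

In the fused kernels the key/value tensors are of course not materialized at `Hq` heads; each program instance simply derives its kv head index from its query head index (e.g. `hkv = hq // group`), which is the layout detail the note above warns about.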