PaddlePaddle / Paddle

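PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (the 飞桨 (PaddlePaddle) core framework: high-performance single-machine and distributed training, and cross-platform deployment, for deep learning and machine learning)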
http://www.paddlepaddle.org/
Apache License 2.0
21.66k stars 5.44k forks

Flashattention support qkvpacked and varlen #63289

Closed kircle888 closed 1 week ago

kircle888 commented 1 month ago

PR Category

Performance Optimization

PR Types

Performance

Description

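FlashAttention now supports QKV-packed input and padded varlen input, which avoids pre-processing and post-processing overhead in several cases.

nn.functional.flash_attention gains two functions: flash_attn_qkvpacked and flash_attn_varlen_qkvpacked.

flash_attn_qkvpacked takes a 5-D input qkv of shape [batchsize, seqlen, num_heads/num_heads_k + 2, num_heads_k, head_dim].

flash_attn_varlen_qkvpacked takes a 4-D input qkv of shape [total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]. When the varlen_padded parameter is False, the input and output are in unpadded form (similar to flash_attn_unpadded). When varlen_padded is True, the input and output are in padded form, i.e. the total_seq_len dimension can be obtained directly by flattening the batchsize and seqlen dimensions of the flash_attn_qkvpacked input.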
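To make the shape arithmetic in the description concrete, here is a minimal pure-Python sketch that only computes the expected qkv shapes; it does not call Paddle. The helper names `qkvpacked_shape` and `varlen_qkvpacked_shape` and the parameters `seq_lens` and `max_seqlen` are hypothetical and not part of this PR. It also assumes that num_heads is a multiple of num_heads_k, and that in unpadded form total_seq_len is the sum of the per-sequence lengths, as is usual for flash_attn_unpadded-style inputs.

```python
def qkvpacked_shape(batchsize, seqlen, num_heads, num_heads_k, head_dim):
    """Expected shape of the 5-D qkv input described for flash_attn_qkvpacked.

    Hypothetical helper for illustration only; not a Paddle API.
    """
    if num_heads % num_heads_k != 0:
        raise ValueError("num_heads must be a multiple of num_heads_k")
    groups = num_heads // num_heads_k  # the num_heads/num_heads_k term
    return (batchsize, seqlen, groups + 2, num_heads_k, head_dim)


def varlen_qkvpacked_shape(seq_lens, max_seqlen, num_heads, num_heads_k,
                           head_dim, varlen_padded):
    """Expected shape of the 4-D qkv input for flash_attn_varlen_qkvpacked.

    Hypothetical helper for illustration only; not a Paddle API.
    seq_lens holds the true length of each sequence in the batch.
    """
    if num_heads % num_heads_k != 0:
        raise ValueError("num_heads must be a multiple of num_heads_k")
    if max(seq_lens) > max_seqlen:
        raise ValueError("a sequence is longer than max_seqlen")
    groups = num_heads // num_heads_k
    if varlen_padded:
        # Padded form: batchsize and seqlen dimensions flattened together.
        total_seq_len = len(seq_lens) * max_seqlen
    else:
        # Unpadded form (assumption): sequences concatenated without padding.
        total_seq_len = sum(seq_lens)
    return (total_seq_len, groups + 2, num_heads_k, head_dim)


dense = qkvpacked_shape(2, 8, num_heads=16, num_heads_k=4, head_dim=64)
padded = varlen_qkvpacked_shape([5, 8], 8, 16, 4, 64, varlen_padded=True)
unpadded = varlen_qkvpacked_shape([5, 8], 8, 16, 4, 64, varlen_padded=False)
print(dense)     # (2, 8, 6, 4, 64)
print(padded)    # (16, 6, 4, 64)
print(unpadded)  # (13, 6, 4, 64)
```

With two sequences of lengths 5 and 8 padded to 8, the padded form has 2 * 8 = 16 rows, which is exactly the dense input with its first two dimensions flattened, while the unpadded form has only 5 + 8 = 13 rows.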

paddle-bot[bot] commented 1 month ago

Your PR has been submitted. Thanks for your contribution to this open-source project! Please wait for the results of the automated CI tests first. See the Paddle CI Manual for details.

CLAassistant commented 1 month ago

CLA assistant check
All committers have signed the CLA.

paddle-ci-bot[bot] commented 1 month ago

Sorry to inform you that the CIs for 9a130e6 passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.

GuoxiaWang commented 3 weeks ago

API doc https://github.com/PaddlePaddle/docs/pull/6608