Closed · kircle888 closed 1 week ago
Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI first. See the Paddle CI Manual for details.
Sorry to inform you that 9a130e6's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
PR Category
Performance Optimization
PR Types
Performance
Description
FlashAttention now supports QKVPacked input and Padded Varlen input, avoiding pre- and post-processing overhead in several scenarios:

- `nn.functional.flash_attention` adds `flash_attn_qkvpacked` and `flash_attn_varlen_qkvpacked`.
- `flash_attn_qkvpacked` accepts a 5-D `qkv` input of shape `[batchsize, seqlen, num_heads/num_heads_k + 2, num_heads_k, head_dim]`.
- `flash_attn_varlen_qkvpacked` accepts a 4-D `qkv` input of shape `[total_seq_len, num_heads/num_heads_k + 2, num_heads_k, head_dim]`. When the `varlen_padded` argument is `False`, inputs and outputs are in unpadded form (similar to `flash_attn_unpadded`); when `varlen_padded` is `True`, inputs and outputs are in padded form (i.e. the `total_seq_len` dimension can be obtained by directly flattening the `batchsize` and `seqlen` dimensions of the `flash_attn_qkvpacked` input).
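The packed layout described above can be sketched with plain NumPy. This is a hypothetical illustration of how separate Q/K/V tensors map onto the `[batchsize, seqlen, num_heads/num_heads_k + 2, num_heads_k, head_dim]` shape and onto the flattened varlen-padded form; the grouping order of query heads is an assumption for demonstration, not the Paddle API itself.

```python
import numpy as np

# Hypothetical example shapes (not from the PR).
batch, seqlen, head_dim = 2, 16, 64
num_heads, num_heads_k = 8, 2            # GQA: 4 query heads per KV head
group = num_heads // num_heads_k         # num_heads / num_heads_k

# Separate projections, as they would come out of attention linear layers.
q = np.random.randn(batch, seqlen, num_heads, head_dim).astype("float32")
k = np.random.randn(batch, seqlen, num_heads_k, head_dim).astype("float32")
v = np.random.randn(batch, seqlen, num_heads_k, head_dim).astype("float32")

# Pack into [batch, seqlen, group + 2, num_heads_k, head_dim]:
# the first `group` slots hold the query heads grouped per KV head
# (assumed contiguous ordering), the last two slots hold K and V.
q_grouped = q.reshape(batch, seqlen, num_heads_k, group, head_dim)
q_grouped = q_grouped.transpose(0, 1, 3, 2, 4)   # [b, s, group, nheads_k, d]
qkv = np.concatenate([q_grouped, k[:, :, None], v[:, :, None]], axis=2)
assert qkv.shape == (batch, seqlen, group + 2, num_heads_k, head_dim)

# For flash_attn_varlen_qkvpacked with varlen_padded=True, total_seq_len is
# simply the flattened batch and seqlen dimensions:
qkv_varlen = qkv.reshape(batch * seqlen, group + 2, num_heads_k, head_dim)
assert qkv_varlen.shape == (batch * seqlen, group + 2, num_heads_k, head_dim)
```

This only demonstrates the tensor layout; the actual kernels consume the packed tensor directly, which is what removes the pre/post-processing overhead mentioned above.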