Open sleepwalker2017 opened 1 month ago
@sleepwalker2017 decoding阶段已经不是compute bound了,而是memory bound.
即从gmem加载到smem再加载到reg这个过程比矩阵乘法还要长,时间都浪费在io上了。
而原版的flash attention一个q block是一个并行度。在decoding时只需要处理一个query就只有一个并行度在执行,然后沿着sequence len执行for循环加载k v. 所以flash decoding在处理长序列时把一个query的处理分成多个block并行处理, 把这个for循环并行化了。
我还没看过flash decoding部分的代码,应该是有调整分块大小这部分的优化的,好像叫flash_fwd_splitkv_kernel
官方flash attention用了很多模板就是对各种情况特化出最优的分块大小。
@sleepwalker2017 decoding阶段已经不是compute bound了,而是memory bound.
即从gmem加载到smem再加载到reg这个过程比矩阵乘法还要长,时间都浪费在io上了。
而原版的flash attention一个q block是一个并行度。在decoding时只需要处理一个query就只有一个并行度在执行,然后沿着sequence len执行for循环加载k v. 所以flash decoding在处理长序列时把一个query的处理分成多个block并行处理, 把这个for循环并行化了。
我还没看过flash decoding部分的代码,应该是有调整分块大小这部分的优化的,好像叫
flash_fwd_splitkv_kernel
官方flash attention用了很多模板就是对各种情况特化出最优的分块大小。
分块的大小,我初步看,只是限制了 64 和 128 两种可选。
我测试了batch=128, seq_len=256的情况下,flash Attention 的 decoding 阶段 在不同 BlockM 和 BlockN配置下的耗时对比
看起来,这个block 配置对性能的影响还是挺明显的。我再研究研究。
@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的
@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释
@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释
还真是!这个分块很小了!我学习下,感谢!
@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的
我理解 decoding 阶段的开销应是线性的吧?Prefill 是平方的
@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的
我理解 decoding 阶段的开销应是线性的吧?Prefill 是平方的
@sleepwalker2017 这和q k长度有关。decoding时q长度是1所以线性
@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释
我看了下,这块不是 Attention kernel,是负责把不同的 split 合并到一起的 reduce 操作。 也就是说当split_num > 1的时候,会 launch 两个 kernel。 前面那个 kernel 还是 64 或者 128 的分块。 后面的 reduce 操作才用的小的分块。
我理解要做GEMM,如果是mma16x8x16指令,最少也得用 16 的kBlockM。 本来想改下这块,但是看了里面的mma 之类的原语还不太熟悉,看起来有很多 assert,限制了一个 block 里的 threads必须是 128 或者 64。
用 Nsight 看了下,确实是
我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。
在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING
可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难
可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难
嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益
可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难
嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益
您好,能留个联系方式不?我暑假也在研究flash-attention系列推理的东西,想交流一下
加你了
------------------ 原始邮件 ------------------ 发件人: eljrte @.> 发送时间: 2024年9月25日 19:42 收件人: 66RING/tiny-flash-attention @.> 抄送: fade_away @.>, Mention @.> 主题: Re: [66RING/tiny-flash-attention] 请教decoding阶段计算浪费的问题 (Issue #8)
可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难
嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益
您好,能留个联系方式不?我暑假也在研究flash-attention系列推理的东西,想交流一下
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you were mentioned.Message ID: @.***>
我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。
在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING
你在什么配置的机器上做的,能说一下机器配置吗,SM80? GPU的具体型号,是否用过 cupti 打点做过细粒度分析?
我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。 在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING
你在什么配置的机器上做的,能说一下机器配置吗,SM80? GPU的具体型号,是否用过 cupti 打点做过细粒度分析?
H100 , cupti 我不熟悉,没有用这个做过分析
目前 decoding 阶段的计算浪费是和什么有关? 是不是
kBlockM
这个值决定的?在 Flash decoding 阶段,我看到的
kBlockM = 64
,所以计算的浪费是63/64
,是这样吗?好奇为什么不把这个值设置为 16 呢?这样依然可以利用 Tensor Core,而且浪费的算力还少很多。
感谢!