66RING / tiny-flash-attention

flash attention tutorial written in python, triton, cuda, cutlass
194 stars 14 forks source link

请教decoding阶段计算浪费的问题 #8

Open sleepwalker2017 opened 1 month ago

sleepwalker2017 commented 1 month ago

目前 decoding 阶段的计算浪费是和什么有关? 是不是kBlockM 这个值决定的?

在 Flash decoding 阶段,我看到的kBlockM = 64,所以计算的浪费是63/64,是这样吗?

好奇为什么不把这个值设置为 16 呢?这样依然可以利用 Tensor Core,而且浪费的算力还少很多。

感谢!

66RING commented 1 month ago

@sleepwalker2017 decoding阶段已经不是compute bound了,而是memory bound.

即从gmem加载到smem再加载到reg这个过程比矩阵乘法还要长,时间都浪费在io上了。

而原版的flash attention一个q block是一个并行度。在decoding时只需要处理一个query就只有一个并行度在执行,然后沿着sequence len执行for循环加载k v. 所以flash decoding在处理长序列时把一个query的处理分成多个block并行处理, 把这个for循环并行化了。

我还没看过flash decoding部分的代码,应该是有调整分块大小这部分的优化的,好像叫flash_fwd_splitkv_kernel官方flash attention用了很多模板就是对各种情况特化出最优的分块大小。

sleepwalker2017 commented 1 month ago

@sleepwalker2017 decoding阶段已经不是compute bound了,而是memory bound.

即从gmem加载到smem再加载到reg这个过程比矩阵乘法还要长,时间都浪费在io上了。

而原版的flash attention一个q block是一个并行度。在decoding时只需要处理一个query就只有一个并行度在执行,然后沿着sequence len执行for循环加载k v. 所以flash decoding在处理长序列时把一个query的处理分成多个block并行处理, 把这个for循环并行化了。

我还没看过flash decoding部分的代码,应该是有调整分块大小这部分的优化的,好像叫flash_fwd_splitkv_kernel官方flash attention用了很多模板就是对各种情况特化出最优的分块大小。

分块的大小,我初步看,只是限制了 64 和 128 两种可选。

我测试了batch=128, seq_len=256的情况下,flash Attention 的 decoding 阶段 在不同 BlockM 和 BlockN配置下的耗时对比 image

看起来,这个block 配置对性能的影响还是挺明显的。我再研究研究。

66RING commented 1 month ago

@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的

66RING commented 1 month ago

@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释

image

sleepwalker2017 commented 1 month ago

@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释

image

还真是!这个分块很小了!我学习下,感谢!

sleepwalker2017 commented 1 month ago

@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的

我理解 decoding 阶段的开销应是线性的吧?Prefill 是平方的

66RING commented 1 month ago

@sleepwalker2017 这个seqlen可能太小了。batch size虽然大但是影响能力不一样。seqlen是平方级的影响。计算开销大概是bs seqlen seqlen这样的

我理解 decoding 阶段的开销应是线性的吧?Prefill 是平方的

@sleepwalker2017 这和q k长度有关。decoding时q长度是1所以线性

sleepwalker2017 commented 1 month ago

@sleepwalker2017 刚看了一下,官方的这个位置是有flash decoding分块大小的优化的。可以看那段注释

image

我看了下,这块不是 Attention kernel,是负责把不同的 split 合并到一起的 reduce 操作。 也就是说当split_num > 1的时候,会 launch 两个 kernel。 前面那个 kernel 还是 64 或者 128 的分块。 后面的 reduce 操作才用的小的分块。

我理解要做GEMM,如果是mma16x8x16指令,最少也得用 16 的kBlockM。 本来想改下这块,但是看了里面的mma 之类的原语还不太熟悉,看起来有很多 assert,限制了一个 block 里的 threads必须是 128 或者 64。

sleepwalker2017 commented 1 month ago

image 用 Nsight 看了下,确实是

sleepwalker2017 commented 1 month ago

image 我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。

在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING

66RING commented 1 month ago

可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难

sleepwalker2017 commented 1 month ago

可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难

嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益

eljrte commented 1 month ago

可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难

嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益

您好,能留个联系方式不?我暑假也在研究flash-attention系列推理的东西,想交流一下

sleepwalker2017 commented 1 month ago

加你了

------------------ 原始邮件 ------------------ 发件人: eljrte @.> 发送时间: 2024年9月25日 19:42 收件人: 66RING/tiny-flash-attention @.> 抄送: fade_away @.>, Mention @.> 主题: Re: [66RING/tiny-flash-attention] 请教decoding阶段计算浪费的问题 (Issue #8)

可以的 你发现了华点。不过这个在不同设备上的情况可能不同。要一个配置适配所有设备很难

嗯嗯 通用的不好搞 专门在一个设备上调优可能会有些收益

您好,能留个联系方式不?我暑假也在研究flash-attention系列推理的东西,想交流一下

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you were mentioned.Message ID: @.***>

A-transformer commented 3 weeks ago

image 我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。

在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING

你在什么配置的机器上做的,能说一下机器配置吗,SM80? GPU的具体型号,是否用过 cupti 打点做过细粒度分析?

sleepwalker2017 commented 3 weeks ago

image 我测了下不同的kBlockM/kBlockN的配置,四列是不同的配置。 在小 batch 下,应该是 SM 资源没打满,所以都能一次性调度,性能没啥差别。 在大 batch 下,我理解kBlockM=128 的情况可能是浪费了 shared memory 和寄存器资源,导致并行度下降了。 看起来性能差距还挺明显的。 如果能进一步降低kBlockM的值,应该还能增加并行度。 @66RING

你在什么配置的机器上做的,能说一下机器配置吗,SM80? GPU的具体型号,是否用过 cupti 打点做过细粒度分析?

H100 , cupti 我不熟悉,没有用这个做过分析