Hprairie closed this 2 months ago
Thanks for the detailed investigation! This is awesome.
Last time I profiled, the reverse cumsum was taking around 10-20us compared to hundreds of microseconds for the other kernels. Have you profiled the whole backward pass?
I will take a deeper look and profile the whole backward pass. I'll also say that I don't know how much this matters with torch.compile(). I'll ping you if I find that there is a noticeable speedup, and draft a PR.
Hey Tri, after doing more profiling I can confirm that you are correct: the overhead is essentially negligible, only roughly 20 us regardless of seqlen. Here is the benchmark of the full bwd-pass kernel with an in-place revsum vs. flipping.
Reverse Associative Scan Performance in the Full Bwd Kernel of Mamba2 (times in ms):

| seqlen | Optimized Triton Kernel | Naive PyTorch Flipping |
| -----: | ----------------------: | ---------------------: |
| 32 | 0.214016 | 0.238592 |
| 64 | 0.236544 | 0.238592 |
| 128 | 0.222208 | 0.260096 |
| 256 | 0.237568 | 0.268288 |
| 512 | 0.274432 | 0.277504 |
| 1024 | 0.360448 | 0.363520 |
| 2048 | 0.654336 | 0.663552 |
| 4096 | 1.661952 | 1.672192 |
| 8192 | 3.486720 | 3.502080 |
A dedicated revsum kernel therefore seems of limited benefit, even though it gives a slight performance gain. This is also in part due to Triton's known bugs around associative scans.
I have been doing some tinkering on the Mamba2 kernels, and I think that in the bwd pass we could save a significant chunk of computation by using a custom kernel for the reverse cumsum. Specifically, I am talking about line 437 in
ssd_combined.py
. The flip operator in PyTorch performs a copy, and is thus fairly costly. While not incredibly essential, I do think optimizing this could save some computation in the bwd pass. I have written the following script to test whether this holds true.
And running it, we can see that the runtime is cut to roughly a third. Here is the output:
Let me know if you would be interested in adding a small kernel to do this manually until PyTorch adds a reverse method. I would be very happy to draft a PR; I just want to check first before doing too much work.
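As a side note for anyone landing on this thread: there is also a flip-free way to express a reverse cumsum using the identity `revcumsum(x) = sum(x) - cumsum(x) + x` (for a torch tensor, `x.sum(dim, keepdim=True) - x.cumsum(dim) + x`), which avoids the copies from `torch.flip` without a custom kernel. This is not necessarily what the Triton kernel above does; the sketch below just checks the identity with plain Python lists, and the helper names are illustrative, not from the repo.

```python
from itertools import accumulate

def revcumsum_flip(xs):
    # The flip-cumsum-flip pattern discussed in this thread:
    # reverse, take an inclusive cumsum, reverse back (two copies).
    return list(accumulate(reversed(xs)))[::-1]

def revcumsum_noflip(xs):
    # Flip-free identity: total - inclusive cumsum + x, elementwise.
    total = sum(xs)
    return [total - c + x for c, x in zip(accumulate(xs), xs)]

xs = [1.0, 2.0, 3.0, 4.0]
print(revcumsum_flip(xs))    # [10.0, 9.0, 7.0, 4.0]
print(revcumsum_noflip(xs))  # [10.0, 9.0, 7.0, 4.0]
```

One caveat with the flip-free form: for floating-point inputs it accumulates differently than the flip-based version, so the results can differ by rounding error even though they are mathematically identical.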