state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Optimizing the bwd pass of Mamba 2 #530

Closed. Hprairie closed this issue 2 months ago.

Hprairie commented 3 months ago

I have been tinkering with the Mamba2 kernels, and I think we could save a significant chunk of computation in the bwd pass by using a custom kernel for the reverse cumsum. Specifically, I am talking about line 437 in ssd_combined.py. torch.flip always copies its input rather than returning a view, so flipping, taking the cumsum, and flipping back materializes extra copies of the tensor, which is fairly costly. While not essential, I think optimizing this could save some computation in the bwd pass.
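For reference, the operation in question is a suffix sum along the last axis: out[i] = x[i] + x[i+1] + ... + x[n-1]. The pure-Python sketch below is not code from the Mamba repo (the function names are hypothetical); it contrasts the flip/cumsum/flip formulation, which builds two reversed copies, with a single backward pass that accumulates from the end and writes in place.

```python
from itertools import accumulate


def rev_cumsum_flip(xs):
    # Mirrors x.flip([-1]).cumsum(-1).flip([-1]): reverse (copy), scan, reverse (copy).
    flipped = xs[::-1]
    scanned = list(accumulate(flipped))
    return scanned[::-1]


def rev_cumsum_inplace(xs):
    # Single backward pass: walk from the last element to the first and
    # overwrite each entry with the running suffix sum (no reversed copies).
    running = 0
    for i in range(len(xs) - 1, -1, -1):
        running += xs[i]
        xs[i] = running
    return xs


row = [1, 2, 3, 4]
print(rev_cumsum_flip(row))           # [10, 9, 7, 4]
print(rev_cumsum_inplace(list(row)))  # [10, 9, 7, 4]
```

Both give the same result; the difference a dedicated kernel targets is only the extra memory traffic of the two flips.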

I have written the following script to test out if this holds true.

import torch

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_H': 1}),
        triton.Config({'BLOCK_SIZE_H': 2}),
        triton.Config({'BLOCK_SIZE_H': 4}),
        triton.Config({'BLOCK_SIZE_H': 8}),
        triton.Config({'BLOCK_SIZE_H': 16}),
        triton.Config({'BLOCK_SIZE_H': 32}),
        triton.Config({'BLOCK_SIZE_H': 64}),
    ],
    key=['chunk_size', 'nheads'],
)
@triton.jit
def _rev_cumsum_kernel(
    x_ptr,
    chunk_size, nheads,
    x_batch_stride, x_chunk_stride, x_head_stride, x_csize_stride,
    BLOCK_SIZE_CHUNK: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr
    ):
    b_pid = tl.program_id(axis=0)
    chunk_pid = tl.program_id(axis=1)
    head_pid = tl.program_id(axis=2)

    x_ptr += b_pid * x_batch_stride + chunk_pid * x_chunk_stride

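    offs_h = head_pid * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)[:, None]
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)[None, :]
    # Mask off heads past nheads (nheads need not be a multiple of BLOCK_SIZE_H).
    # BLOCK_SIZE_CHUNK must be a power of 2 because of tl.arange.
    mask = (offs_h < nheads) & (offs_c < chunk_size)

    x_ptrs = x_ptr + offs_c * x_csize_stride + offs_h * x_head_stride

    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
    # Reverse (suffix) cumsum along the chunk axis, i.e. axis 1 of the
    # [BLOCK_SIZE_H, BLOCK_SIZE_CHUNK] tile; the default axis=0 would scan across heads.
    x = tl.cumsum(x, axis=1, reverse=True)
    # Write back in place. tl.store returns None, so its result is not assigned.
    tl.store(x_ptrs, x, mask=mask)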

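def _rev_cumsum(x: torch.Tensor):
    """In-place reverse cumsum over the last dim of a (batch, nheads, nchunks, chunk_size) tensor."""
    assert x.dim() == 4
    batch, nheads, nchunks, csize = x.shape

    # One program per (batch, chunk, block of BLOCK_SIZE_H heads).
    # Note: autotuning re-runs the kernel on x, so on the first call x is modified
    # several times; clone the input before using this for a correctness check.
    grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
    with torch.cuda.device(x.device.index):
        _rev_cumsum_kernel[grid_chunk_cs](
            x,
            csize, nheads,
            x.stride(0), x.stride(2), x.stride(1), x.stride(3),
            BLOCK_SIZE_CHUNK=csize,  # csize must be a power of 2
        )
    return x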

def init(batch, nheads, nchunks, chunk_size):
    x = torch.randn(batch, nheads, nchunks, chunk_size, device="cuda")
    return x

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["chunk_length"],  # Argument names to use as an x-axis for the plot.
        x_vals=[
            2**i for i in range(5, 10, 1)
        ],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg="provider",  # Argument name whose value corresponds to a different line in the plot.
        line_vals=["Triton", "Pytorch"],  # Possible values for `line_arg`.
        line_names=["Triton", "Pytorch"],  # Label name for the lines.
        styles=[("blue", "-"), ("green", "-")],  # Line styles.
        ylabel="ms",  # Label name for the y-axis.
        plot_name="Reverse Associative Scan Performance",  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    )
)
def benchmark_fused(chunk_length, provider):
    x = init(8, 64, 128, chunk_length)
    quantiles = [0.5, 0.2, 0.8]  # median, 20th and 80th percentiles
    if provider == "Triton":
        # In place: x is overwritten on every repetition, which is fine for timing.
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: _rev_cumsum(x),
            quantiles=quantiles,
            warmup=300,
            rep=1000,
        )
    if provider == "Pytorch":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: x.flip([-1]).cumsum(dim=-1).flip([-1]),
            quantiles=quantiles,
            warmup=300,
            rep=1000,
        )
    # perf_report expects (value, lower bound, upper bound).
    return ms, min_ms, max_ms

if __name__ == "__main__":
    torch.manual_seed(0)
    device = torch.device("cuda")
    benchmark_fused.run(print_data=True, show_plots=True)
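To make the kernel's pointer arithmetic concrete, here is a CPU-only, pure-Python emulation of what each program instance in the launch grid does. It assumes a contiguous (batch, nheads, nchunks, chunk_size) layout, and the helper names are hypothetical: one program per (batch, chunk, head block), offsets built from strides, heads at or beyond nheads skipped (the bounds check a masked load performs), and a reverse scan along the chunk axis. It is a correctness reference only, not a performance model.

```python
def contiguous_strides(shape):
    # Element strides of a contiguous row-major tensor, like torch.Tensor.stride().
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


def cdiv(a, b):
    # Ceiling division, the same role triton.cdiv plays in the launch grid.
    return (a + b - 1) // b


def emulate_rev_cumsum(flat, shape, block_size_h):
    """In-place reverse cumsum over the last dim of a flat, contiguous
    (batch, nheads, nchunks, csize) buffer."""
    batch, nheads, nchunks, csize = shape
    s_b, s_h, s_c, s_k = contiguous_strides(shape)
    # One "program" per (batch, chunk, head block).
    for b_pid in range(batch):
        for chunk_pid in range(nchunks):
            for head_pid in range(cdiv(nheads, block_size_h)):
                base = b_pid * s_b + chunk_pid * s_c
                for h in range(head_pid * block_size_h, (head_pid + 1) * block_size_h):
                    if h >= nheads:  # out-of-range heads are skipped (masked)
                        continue
                    running = 0.0
                    # Walk the chunk axis backwards, writing suffix sums in place.
                    for k in range(csize - 1, -1, -1):
                        idx = base + h * s_h + k * s_k
                        running += flat[idx]
                        flat[idx] = running
    return flat


# Usage: 3 heads with block_size_h=2 exercises the head bounds check.
shape = (2, 3, 2, 4)
data = [float(i) for i in range(48)]
out = emulate_rev_cumsum(list(data), shape, block_size_h=2)
print(out[:4])  # row (b=0, h=0, c=0): [6.0, 6.0, 5.0, 3.0]
```

Because the layout is contiguous, every run of chunk_size consecutive elements is one (batch, head, chunk) row, which makes the result easy to check against a naive suffix sum.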

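Running it, the in-place Triton kernel takes about a third of the PyTorch time for chunk lengths of 128 and above, and far less at 32 and 64. Here is the output (median times in ms):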

Reverse Associative Scan Performance:
   chunk_length    Triton   Pytorch
0          32.0  0.077824  1.842080
1          64.0  0.151552  1.961984
2         128.0  0.296960  0.881664
3         256.0  0.577536  1.754112
4         512.0  1.168384  3.497984
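The speedup can be checked directly from the reported numbers. The snippet below only divides the two columns; the values are copied from the table, nothing is re-measured.

```python
# Median times in ms, copied from the table above (not re-measured).
chunk_lengths = [32, 64, 128, 256, 512]
triton_ms = [0.077824, 0.151552, 0.296960, 0.577536, 1.168384]
pytorch_ms = [1.842080, 1.961984, 0.881664, 1.754112, 3.497984]

# Speedup = PyTorch time / Triton time for each chunk length.
speedups = {c: p / t for c, t, p in zip(chunk_lengths, triton_ms, pytorch_ms)}
for c, s in speedups.items():
    print(f"chunk_length={c:4d}: {s:5.1f}x")
```

This gives roughly 23.7x and 12.9x at 32 and 64, then about 3.0x at 128, 256 and 512. The PyTorch timings are non-monotonic (1.84 ms at 32 vs 0.88 ms at 128), so the two small-chunk points may be noisy; the ~3x figure at larger chunk lengths is the more stable one.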

Let me know if you would be interested in adding a small kernel to do this manually until PyTorch adds a reverse method. I would be very happy to draft up a PR, just want to check first before doing too much work.

tridao commented 3 months ago

Thanks for the detailed investigation! This is awesome.

Last time I profiled, the reverse cumsum was taking around 10-20us compared to hundreds of microseconds for the other kernels. Have you profiled the whole backward pass?

Hprairie commented 3 months ago

I will take a deeper look at profiling the whole backward pass. I'll also say that I don't know how much this matters under torch.compile(). I'll ping you and draft a PR if I find a noticeable speedup.

Hprairie commented 2 months ago

Hey Tri, after more profiling I can confirm that you are correct: the overhead is essentially negligible, at most a few tens of microseconds regardless of seqlen. Here is the benchmark of the full bwd pass with an in-place reverse cumsum vs. flipping (times in ms):

Reverse Associative Scan Performance in Full Bwd Kernel of Mamba2:
   seqlen  Optimized Triton Kernel  Naive Pytorch Flipping
0    32.0                 0.214016                0.238592
1    64.0                 0.236544                0.238592
2   128.0                 0.222208                0.260096
3   256.0                 0.237568                0.268288
4   512.0                 0.274432                0.277504
5  1024.0                 0.360448                0.363520
6  2048.0                 0.654336                0.663552
7  4096.0                 1.661952                1.672192
8  8192.0                 3.486720                3.502080
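The per-seqlen savings are just the difference of the two columns. The snippet below (values copied from the table above) converts them to microseconds and to a percentage of the naive backward time.

```python
# Full bwd-pass times in ms, copied from the table above (not re-measured).
seqlens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
optimized_ms = [0.214016, 0.236544, 0.222208, 0.237568, 0.274432,
                0.360448, 0.654336, 1.661952, 3.486720]
naive_ms = [0.238592, 0.238592, 0.260096, 0.268288, 0.277504,
            0.363520, 0.663552, 1.672192, 3.502080]

# Absolute saving (us) and saving relative to the naive bwd time (%).
saved_us = [(n - o) * 1e3 for o, n in zip(optimized_ms, naive_ms)]
saved_pct = [100 * (n - o) / n for o, n in zip(optimized_ms, naive_ms)]
for s, us, pct in zip(seqlens, saved_us, saved_pct):
    print(f"seqlen={s:5d}: saves {us:5.1f} us ({pct:4.1f}% of naive bwd time)")
```

The savings peak at about 38 µs (seqlen 128) and are about 1.4% of the backward time or less from seqlen 512 onward: the absolute gain stays small while the backward pass itself grows with seqlen.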

So a dedicated reverse-cumsum kernel does not seem worth adding, even though it gives a slight speedup. Bugs in Triton's associative scan are also part of the reason to hold off.