srush / annotated-mamba

Annotated version of the Mamba paper
MIT License

Minimal mamba in JAX #1

Closed radarFudan closed 8 months ago

radarFudan commented 10 months ago

There is a minimal version in pytorch: https://github.com/johnma2006/mamba-minimal

Looking forward to a minimal version in JAX.

Then maybe we can try to accelerate it? (I'm not an expert in CUDA; is it possible to accelerate it using Mamba's original CUDA code?)

srush commented 10 months ago

I was thinking I would just do this in pytorch as well... I enjoy coding in JAX, but there just doesn't seem to be that much use outside of Google.

davisyoshida commented 10 months ago

JAX is definitely fighting against some tough network effects with all the kernels etc coming out with torch implementations first. @radarFudan I'll probably end up porting this to JAX as part of reading it so maybe I can share it after.

Edit: Just noticed the "might take some time", so maybe I shouldn't get ahead of myself 😅

srush commented 10 months ago

I think that mamba-minimal linked above is supposed to be quite good. I think the annotated version will likely focus on visualizations and graphs to explain the core ideas. But if you just want to port code, that other one is probably enough.

davisyoshida commented 10 months ago

@radarFudan I did a direct port last night, which "works", but the thing it's missing is the associative scan. I'm unsure whether it's possible to implement that with lax's associative scan primitive; I haven't thought about it much yet.

If the slow version would be enough, let me know and I can share it with you.

srush commented 10 months ago

It's definitely possible to implement with associative scan. The question is whether it will be fast enough on GPU.

Might be interesting to eventually try Pallas and write a low level version that works both on GPU and TPU.

(you guys are going to talk me into more jax programming aren't you)

davisyoshida commented 10 months ago

I'm pretty unsure about this since I just started looking at SSMs yesterday, but: the thing that seemed tricky about associative scan is that the body signature needs to be `a -> a -> a`, unlike jax.lax.scan. Since there's no separate "carry", I was thinking I'd have to have some extra unused arrays for the "output" values of A and B.

Additionally, I think it'll also need to involve a branch, since you need different behavior for each input depending on whether it is an original input or an internal node.

Is this what you were talking about when you said it might be slow on GPU?

srush commented 10 months ago

Here's the slide I had on the associative trick for time-invariant SSMs. Might need to adjust in the general case.

[image: slide showing the associative scan formulation for time-invariant SSMs]

davisyoshida commented 10 months ago

Thanks that's super helpful, I'll try this out tonight.

radarFudan commented 10 months ago

Actually this associative scan has been implemented in the S5 repo. https://github.com/lindermanlab/S5/blob/main/s5/ssm.py

Therefore it should be straightforward to adapt it to the selective-scan setting.

radarFudan commented 10 months ago

One thing I'm not sure about is whether JAX can handle the compilation (and backpropagation) of the following recurrent system with the same complexity as the LTI system ($h_{k+1} = Ah_k + Bx_k$).

$$h_{k+1} = A(x_k) h_k + B(x_k)$$

(This is also associative so in theory it's feasible.)
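
A minimal sketch of what this could look like with `jax.lax.associative_scan` (my own illustration for the scalar/diagonal case, using the composition rule $(a_1, b_1) \bullet (a_2, b_2) = (a_1 a_2, a_2 b_1 + b_2)$); reverse-mode differentiation then goes through the scan like any other JAX op:

```python
import jax
import jax.numpy as jnp

def binary_op(l, r):
    # compose two steps of h' = a * h + b: first l, then r
    a_l, b_l = l
    a_r, b_r = r
    return a_l * a_r, a_r * b_l + b_r

def selective_scan(a, b):
    # a[k] = A(x_k), b[k] = B(x_k); returns every h_{k+1} = a_k * h_k + b_k (with h_0 = 0)
    _, h = jax.lax.associative_scan(binary_op, (a, b))
    return h

# gradients flow through the associative scan with standard reverse-mode AD
grad_a = jax.grad(lambda a, b: selective_scan(a, b).sum())(jnp.full(8, 0.9), jnp.arange(8.0))
```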

davisyoshida commented 10 months ago

I used the method from the slide Sasha shared, and I'm seeing a slight slowdown over a sequential scan. I'm guessing it's not able to fuse the multiplication with C onto the associative_scan call, but I'm not 100% sure.

srush commented 10 months ago

Yeah, I think we're all saying the same thing (my slide is based on S5 implementation). I also found associative scan disappointingly slow when I tried it for this on GPU. Kind of a bummer, in theory this is the perfect use case. Maybe we can summon one of the JAX wizards to explain why this isn't super fast.

davisyoshida commented 10 months ago

Oh yeah, I didn't post this yesterday, but it turns out associative_scan isn't actually compiled to a single XLA op the way scan is. When I was looking at the compiled function to see what was getting fused, I saw a ton of repeated ops, and it turns out that's because the scan is done at the JAX level. (code)

So compilation time will actually scale with the input length, and you have to hope that XLA can turn a ton of slices and concats into something fast.

davisyoshida commented 10 months ago

Seems like a Pallas kernel will be the way to go.

Skylion007 commented 10 months ago

Would it be possible to write a Triton version (https://github.com/jax-ml/jax-triton)? That would be useful in JAX + PyTorch (in theory anyway). It would also be compatible with Torch Compile.

srush commented 10 months ago

Yeah, jax-triton = Pallas. Think they use triton as the GPU backend.

I tried it for a different project though and it was still really early stage.

Ryu1845 commented 10 months ago

@johnryan465 is working on a fast version of parallel scan for a bounty. He has a repo with a bunch of implementations here: https://github.com/johnryan465/pscan . No JAX version, but it should still be useful.

radarFudan commented 10 months ago

I've tried to write a JAX-version of the mamba-minimal (https://github.com/radarFudan/mamba-minimal-jax). For now the JAX version is significantly slower (5-6x) than the PyTorch version.

Howuhh commented 10 months ago

I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.
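
Roughly, that log-space trick looks like the sketch below (my own rendering in JAX, assuming positive gates and inputs so real logs suffice; the linked repo handles the general case, and PyTorch has torch.logcumsumexp built in):

```python
import jax
import jax.numpy as jnp

def logcumsumexp(x):
    # cumulative log-sum-exp; logaddexp is associative, so a scan computes it
    return jax.lax.associative_scan(jnp.logaddexp, x)

def log_space_scan(a, b):
    # h_t = a_t * h_{t-1} + b_t (with h_0 = 0), computed via prefix sums in log space
    log_a_star = jnp.cumsum(jnp.log(a))                    # log of prod_{s<=t} a_s
    log_h = log_a_star + logcumsumexp(jnp.log(b) - log_a_star)
    return jnp.exp(log_h)
```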

srush commented 10 months ago

Have you tried training with this? I had checked it out before, but I wasn't sure if it was stable. If it works it fixes lots of things.

Howuhh commented 10 months ago

Not yet, but it is at least faster judging by the tweet...

But I have often used parallel scan in jax. Can't it be used here similar to S5? It would also solve a lot of things if written in jax.

srush commented 10 months ago

As others have noted above, it seems to be slow in this case; it doesn't fuse well. S5 gets away from these issues by doing MIMO.

johnryan465 commented 10 months ago

> I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.

I'm the author of the https://github.com/johnryan465/pscan repo mentioned above and have benchmarked and tested a few different approaches, including that one. It does work, but at long sequence lengths you do start to see numerical errors creeping in (i.e., torch.allclose fails).

Howuhh commented 10 months ago

Yup, maybe it will not work for pre-trained models, but it can work for training new ones faster with simpler code. For educational purposes and small use cases this is better than a naive sequential scan.

vvvm23 commented 9 months ago

I've created a repository that implements Mamba in JAX with associative scans. It isn't feature-complete yet or 100% verified to be correct, but seems quite performant. It still lags behind the reference PyTorch+CUDA implementation in speed, however, so I will explore a Pallas kernel next in this repository.

It may be of interest to those in this thread for your own implementations: https://github.com/vvvm23/mamba-jax

davisyoshida commented 9 months ago

Just an update on the pallas angle, it seems that triton doesn't support a multi-tensor associative scan (openai/triton#2657). So to do it in pallas you probably need to implement it yourself, and I don't know how that will compare on efficiency. pl.load doesn't support strided slicing, and once you've loaded from a ref it doesn't look like you can slice at all, so I'm not sure if it's even possible atm. Take that all with a mound of salt since I'm just starting to look at pallas.

srush commented 9 months ago

Bummer. Maybe it is worth trying to merge the tensors and paying the memory overhead just to see if it works at all in triton.

Just saw this pytorch version as well https://github.com/alxndrTL/mamba.py which seems to have good perf. @alxndrTL

(I haven't made too much annotated-mamba progress, but I will likely have a documented pytorch pscan and a heinsen version. I'll include you guys as authors if that is okay. )

srush commented 9 months ago

Think I got a Triton version working! It passes my tests, at least. (Based on @proger's trick here: https://github.com/proger/accelerated-scan/blob/main/accelerated_scan/triton.py)

import torch
import triton
import triton.language as tl

@triton.jit
def unpack64(merged):
    # split a uint64 back into two float32s (high 32 bits -> a, low 32 bits -> b)
    tl.static_assert(merged.dtype == tl.uint64)
    b = (merged & 0xFFFFFFFF).to(tl.uint32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.uint32).to(tl.float32, bitcast=True)
    return a, b

@triton.jit
def pack64(a, b):
    # pack two float32s into one uint64 so a pair of values can go through tl.associative_scan
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    a = a.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    a = a << 32
    b = b.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    return a | b

@triton.jit
def first_order_op(l, r):
    # compose two first-order updates: (fl, xl) then (fr, xr) -> (fl * fr, fr * xl + xr)
    fl, xl = unpack64(l)
    fr, xr = unpack64(r)
    f = fl * fr
    x = fr * xl + xr
    return pack64(f, x)

@triton.jit
def mamba_kernel(A_ptr, B_ptr, C_ptr, d_ptr, x_ptr, y_ptr, B,  L, D, 
                 N:tl.constexpr, L_SIZE: tl.constexpr, D_SIZE: tl.constexpr):
    # A_ptr: D x N
    # B_ptr, C_ptr: B x L x N
    # d_ptr: B x D x L
    # x, y: B x D x L

    # Loads
    Ls, Ns, Ds = tl.arange(0, L_SIZE), tl.arange(0, N)[:, None], tl.arange(0, D_SIZE)[:, None]
    B_idx, D_idx = tl.program_id(0), tl.program_id(1) * D_SIZE
    d = tl.load(d_ptr + L * (D_idx + Ds) + D * L * B_idx + Ls) # D x L
    A = tl.load(A_ptr + N * D_idx + Ns) # N x 1
    B = tl.load(B_ptr + L * N * B_idx + N * Ls + Ns) # N x L
    C = tl.load(C_ptr + L * N * B_idx + N * Ls + Ns) # N x L
    x = tl.load(x_ptr + L * D * B_idx + L * (D_idx + Ds) + Ls) # D x L

    # Discretize
    A = A[:, :, None]
    A_bar = tl.exp(d * A)
    B_bar = ((A_bar - A) / A) * B[:, None, :]
    h = B_bar * x # N x D x L

    # Scan
    stack = pack64(A_bar, h)
    h = tl.associative_scan(stack, 2, first_order_op)
    _, h = unpack64(h)
    y = tl.sum(C[:, None, :] * h, 0) # D x L

    # Write out
    tl.store(y_ptr + L * D * 2 * B_idx  + L * (D_idx + Ds) + Ls, y)
davisyoshida commented 9 months ago

The packing method is way smarter than the weird hacks I was trying to use to get it working with pallas. I'll see if I can make this work with JAX as well.

srush commented 9 months ago

haha, yeah I tried some crazy hacks too. this one feels pretty clean and should work nicely for backwards and f16 as well.

davisyoshida commented 9 months ago

I'm guessing there's no way to beat Triton's associative scan, since the only ways I could accomplish strided slices and updates involved writes to global memory.

vvvm23 commented 9 months ago

Nice work! I might give porting this to Pallas a shot and see if a similar method can work for TPU backends too.

davisyoshida commented 9 months ago

I got a Pallas kernel working using the packing trick. Unfortunately, it's still losing to jax.lax.associative_scan once the sequence length gets long enough. These numbers are from a Turing GPU; the results might be different on more recent hardware.

[image: benchmark timings comparing the Pallas kernel with jax.lax.associative_scan]

vvvm23 commented 9 months ago

Could you share code for this? I can run it on a more modern GPU and see what the results are

davisyoshida commented 9 months ago

Here's the code; it's a port of the mamba-minimal code, but with the sequence dimension at the end. Making the block size or sequence length larger leads to significantly increased compilation times, but I haven't dug into why yet.

from functools import partial

import jax
from jax.core import Primitive, ShapedArray
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax._src.pallas.triton.lowering import triton_lowering_rules, _associative_scan_lowering

import triton.language as tl

# ------ Define lowering rules for casting ------

type_map = {
    jnp.float16.dtype: tl.float16,
    jnp.bfloat16.dtype: tl.bfloat16,
    jnp.float32.dtype: tl.float32,
    jnp.float64.dtype: tl.float64,
    jnp.uint8.dtype: tl.uint8,
    jnp.uint16.dtype: tl.uint16,
    jnp.uint32.dtype: tl.uint32,
    jnp.uint64.dtype: tl.uint64,
    jnp.int8.dtype: tl.int8,
    jnp.int16.dtype: tl.int16,
    jnp.int32.dtype: tl.int32,
    jnp.int64.dtype: tl.int64,
}

def bitcast_convert_type_lowering(ctx, arg, new_dtype):
    out_dtype = type_map[new_dtype]
    return arg.to(out_dtype, bitcast=True, _builder=ctx.builder)

triton_lowering_rules[jax.lax.bitcast_convert_type_p] = bitcast_convert_type_lowering

def convert_element_type_lowering(ctx, arg, new_dtype, weak_type):
    out_dtype = type_map[new_dtype]
    return arg.to(out_dtype, _builder=ctx.builder)

triton_lowering_rules[jax.lax.convert_element_type_p] = convert_element_type_lowering

# ------ Packing ------
# Compiling these individually as in Sasha's code may be better
def unpack(val):
    assert val.dtype == jnp.uint64
    mask = jnp.array(0xFFFFFFFF, dtype=jnp.uint64)
    a = (val & mask).astype(jnp.uint32)
    a = jax.lax.bitcast_convert_type(
        a,
        jnp.float32
    )

    b = (val >> 32).astype(jnp.uint32)
    b = jax.lax.bitcast_convert_type(
        b,
        jnp.float32
    )
    return a, b

def pack(a, b):
    assert a.shape == b.shape
    assert a.dtype == jnp.float32
    assert b.dtype == jnp.float32
    a = jax.lax.bitcast_convert_type(a, jnp.uint32).astype(jnp.uint64)
    b = jax.lax.bitcast_convert_type(b, jnp.uint32).astype(jnp.uint64)
    assert a.dtype == jnp.uint64
    assert b.dtype == jnp.uint64
    result = (b << 32) | a
    assert result.dtype == jnp.uint64
    return result

# ------ Scan primitive ------
mamba_scan_p = Primitive('mamba_scan')
def mamba_scan_impl(x):
    raise NotImplementedError

def mamba_scan_abstract_eval(x):
    return ShapedArray(x.shape, x.dtype)

mamba_scan_p.def_impl(mamba_scan_impl)
mamba_scan_p.def_abstract_eval(mamba_scan_abstract_eval)

def scan_op(l, r):
    a1, b1 = unpack(l)
    a2, b2 = unpack(r)

    result = (a1 * a2, a2 * b1 + b2)
    return pack(*result)

def scan_lowering(ctx, arg):
    # I'm not sure why I need to use `enable_x64` again,
    # but removing it causes a crash
    with jax.experimental.enable_x64():
        return _associative_scan_lowering(
            scan_op, ctx, arg, (1,)
        )[0]

triton_lowering_rules[mamba_scan_p] = scan_lowering

# ------ Main kernel ------
def mamba_kernel(u, delta, A, B, C, D, out_ref):
    """
    u: (d_inner, seq)
    delta: (d_inner, seq)
    A: (d_inner, channels)
    B: (channels, seq)
    C: (channels, seq)
    D: (d_inner,)
    out_ref: (d_inner, seq)
    """
    delta = delta[...][:, None, :]
    u = u[...]

    deltaA = jnp.exp(delta * A[...][:, :, None])
    deltaBu = delta * B[...][None, :, :] * u[:, None, :]

    packed = pack(deltaA, deltaBu)
    result = mamba_scan_p.bind(packed)

    scan_output = unpack(result)[1]

    y = jnp.sum(scan_output * C[...][None, :, :], axis=1)

    out_ref[...] = y + u * D[...][:, None]

# ------ Interface ------
@partial(jax.jit, static_argnames=('block_size',))
def selective_scan(u, delta, A, B, C, D, block_size=8):
    d_inner, seq = u.shape

    out_struct = jax.ShapeDtypeStruct(
        shape=u.shape,
        dtype=delta.dtype,
    )

    return pl.pallas_call(
        mamba_kernel,
        out_shape=out_struct,
        grid=(d_inner // block_size,),
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (block_size, u.shape[1])),
            pl.BlockSpec(lambda i: (i, 0), (block_size, delta.shape[1])),
            pl.BlockSpec(lambda i: (i, 0), (block_size, A.shape[1])),
            pl.BlockSpec(lambda i: (0, 0), B.shape),
            pl.BlockSpec(lambda i: (0, 0), C.shape),
            pl.BlockSpec(lambda i: (i,), (block_size,)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (block_size, u.shape[1])),
    )(u, delta, A, B, C, D)
srush commented 9 months ago

Very neat that it runs. Will be interesting to understand why it is slower.

I'm working on the backwards implementation, which is a bit more annoying (can pallas autodiff your version?). Anyone know how to flip/reverse a matrix in triton? Do I have to make a matrix to do it?

davisyoshida commented 9 months ago

> Do you think we need to block over length?

I was wondering about this as well. One possibility is to do blocks sequentially and pass the final value from the previous block into the kernel. It's possible that this is effectively what jax.lax.associative_scan ends up doing with its various fused blocks (see below).

> Do you know what associative scan does

I'm not sure what XLA will do in general, but on my system it creates a different number of fused ops depending on what I use for the sequence length. Here's some code to inspect the compiled scan:

Code:

```python
from functools import partial

import jax
import jax.numpy as jnp

def op(l, r):
    return (
        l[0] * r[0],
        l[1] * r[0] + r[1]
    )

def f(x):
    return jax.lax.associative_scan(op, x)

def main():
    n = 2
    args = (
        jnp.arange(n, dtype=jnp.float32),
        jnp.arange(n, dtype=jnp.float32),
    )
    jitted_f = jax.jit(partial(jax.lax.associative_scan, op))
    compiled = jitted_f.lower(args).compile()
    print(compiled.as_text())

if __name__ == '__main__':
    main()
```

Here are the compiled outputs:

n=2 ``` HloModule jit__unnamed_wrapped_function_, is_scheduled=true, entry_computation_layout={(f32[2]{0}, f32[2]{0})->(f32[2]{0}, f32[2]{0})}, allow_spmd_sharding_propagation_to_output={true,true} %fused_computation (param_0.4: f32[2], param_1.8: f32[2]) -> (f32[2], f32[2]) { %param_0.4 = f32[2]{0} parameter(0) %slice.20 = f32[1]{0} slice(f32[2]{0} %param_0.4), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %constant_0 = f32[] constant(0) %pad.15 = f32[2]{0} pad(f32[1]{0} %slice.20, f32[] %constant_0), padding=0_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.19 = f32[1]{0} slice(f32[2]{0} %param_0.4), slice={[0:1:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %param_1.8 = f32[2]{0} parameter(1) %slice.23 = f32[1]{0} slice(f32[2]{0} %param_1.8), slice={[1:2:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.4 = f32[1]{0} multiply(f32[1]{0} %slice.19, f32[1]{0} %slice.23), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.18 = f32[1]{0} slice(f32[2]{0} %param_0.4), slice={[1:2:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.7 = f32[1]{0} add(f32[1]{0} %multiply.4, f32[1]{0} %slice.18), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %pad.12 = f32[2]{0} pad(f32[1]{0} %add.7, f32[] %constant_0), padding=1_0, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.6 = f32[2]{0} add(f32[2]{0} %pad.15, f32[2]{0} %pad.12), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.22.clone.1 = f32[1]{0} slice(f32[2]{0} %param_1.8), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.19.clone.1 = f32[2]{0} pad(f32[1]{0} %slice.22.clone.1, f32[] %constant_0), padding=0_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.21.clone.1 = f32[1]{0} slice(f32[2]{0} %param_1.8), slice={[0:1:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.5.clone.1 = f32[1]{0} multiply(f32[1]{0} %slice.21.clone.1, f32[1]{0} %slice.23), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=7} %pad.18.clone.1 = f32[2]{0} pad(f32[1]{0} %multiply.5.clone.1, f32[] %constant_0), padding=1_0, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.8.clone.1 = f32[2]{0} add(f32[2]{0} 
%pad.19.clone.1, f32[2]{0} %pad.18.clone.1), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} ROOT %tuple = (f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %add.6, f32[2]{0} %add.8.clone.1) } ENTRY %main.20 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2], f32[2]) { %Arg_1.2 = f32[2]{0} parameter(1), sharding={replicated} %Arg_0.1 = f32[2]{0} parameter(0), sharding={replicated} %fusion = (f32[2]{0}, f32[2]{0}) fusion(f32[2]{0} %Arg_1.2, f32[2]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %get-tuple-element.1 = f32[2]{0} get-tuple-element((f32[2]{0}, f32[2]{0}) %fusion), index=1 %get-tuple-element = f32[2]{0} get-tuple-element((f32[2]{0}, f32[2]{0}) %fusion), index=0 ROOT %tuple.19 = (f32[2]{0}, f32[2]{0}) tuple(f32[2]{0} %get-tuple-element.1, f32[2]{0} %get-tuple-element), frontend_attributes={fingerprint_before_lhs="ab53525be2df83cba0765714159b08c5"} } ```
n=4 ``` HloModule jit__unnamed_wrapped_function_, is_scheduled=true, entry_computation_layout={(f32[4]{0}, f32[4]{0})->(f32[4]{0}, f32[4]{0})}, allow_spmd_sharding_propagation_to_output={true,true} %fused_computation.1 (param_0.37: f32[4], param_1.45: f32[4]) -> f32[2] { %param_1.45 = f32[4]{0} parameter(1) %slice.98 = f32[2]{0} slice(f32[4]{0} %param_1.45), slice={[0:3:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %param_0.37 = f32[4]{0} parameter(0) %slice.97 = f32[2]{0} slice(f32[4]{0} %param_0.37), slice={[1:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.37 = f32[2]{0} multiply(f32[2]{0} %slice.98, f32[2]{0} %slice.97), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.96 = f32[2]{0} slice(f32[4]{0} %param_1.45), slice={[1:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.26 = f32[2]{0} add(f32[2]{0} %multiply.37, f32[2]{0} %slice.96), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.52 = f32[1]{0} slice(f32[2]{0} %add.26), slice={[0:1:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.78 = f32[2]{0} slice(f32[4]{0} %param_0.37), slice={[0:3:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.27 = f32[2]{0} multiply(f32[2]{0} %slice.78, f32[2]{0} %slice.97), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=7} %slice.62 = f32[1]{0} slice(f32[2]{0} %multiply.27), slice={[1:2:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.17 = f32[1]{0} multiply(f32[1]{0} %slice.52, f32[1]{0} %slice.62), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.51 = f32[1]{0} slice(f32[2]{0} %add.26), slice={[1:2:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.19 = f32[1]{0} add(f32[1]{0} %multiply.17, f32[1]{0} %slice.51), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %constant_1 = f32[] constant(0) ROOT %pad.26 = f32[2]{0} pad(f32[1]{0} %add.19, f32[] %constant_1), padding=1_0, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} } %fused_computation.4 (param_0.35: f32[4], param_1.57: f32[4], param_2.42: f32[2]) -> (f32[4], f32[4]) { %param_0.35 = f32[4]{0} parameter(0) %slice.57 = f32[1]{0} slice(f32[4]{0} %param_0.35), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" 
source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.84 = f32[2]{0} slice(f32[4]{0} %param_0.35), slice={[0:3:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.83 = f32[2]{0} slice(f32[4]{0} %param_0.35), slice={[1:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.29 = f32[2]{0} multiply(f32[2]{0} %slice.84, f32[2]{0} %slice.83), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=7} %slice.82 = f32[1]{0} slice(f32[2]{0} %multiply.29), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %constant_4 = f32[] constant(0) %pad.33 = f32[2]{0} pad(f32[1]{0} %slice.82, f32[] %constant_4), padding=0_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.90 = f32[1]{0} slice(f32[2]{0} %multiply.29), slice={[0:1:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.89 = f32[1]{0} slice(f32[2]{0} %multiply.29), slice={[1:2:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(2,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.34 = f32[1]{0} multiply(f32[1]{0} %slice.90, f32[1]{0} %slice.89), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=7} %pad.35 = f32[2]{0} pad(f32[1]{0} %multiply.34, f32[] %constant_4), padding=1_0, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.23 = f32[2]{0} add(f32[2]{0} %pad.33, f32[2]{0} %pad.35), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.56 = f32[1]{0} slice(f32[2]{0} %add.23), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.64 = f32[1]{0} slice(f32[4]{0} %param_0.35), slice={[2:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(2,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.19 = f32[1]{0} multiply(f32[1]{0} %slice.56, f32[1]{0} %slice.64), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=7} %concatenate.5 = f32[2]{0} concatenate(f32[1]{0} %slice.57, f32[1]{0} %multiply.19), dimensions={0}, metadata={op_name="jit()/jit(main)/concatenate[dimension=0]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.29 = f32[4]{0} pad(f32[2]{0} %concatenate.5, f32[] %constant_4), padding=0_1_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.28 = f32[4]{0} pad(f32[2]{0} %add.23, f32[] %constant_4), padding=1_0_1, 
metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.21 = f32[4]{0} add(f32[4]{0} %pad.29, f32[4]{0} %pad.28), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %param_1.57 = f32[4]{0} parameter(1) %slice.50.clone.1 = f32[1]{0} slice(f32[4]{0} %param_1.57), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.112.clone.1 = f32[2]{0} slice(f32[4]{0} %param_1.57), slice={[0:3:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(3,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.41.clone.1 = f32[2]{0} multiply(f32[2]{0} %slice.112.clone.1, f32[2]{0} %slice.83), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.110.clone.1 = f32[2]{0} slice(f32[4]{0} %param_1.57), slice={[1:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(1,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.30.clone.1 = f32[2]{0} add(f32[2]{0} %multiply.41.clone.1, f32[2]{0} %slice.110.clone.1), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.109.clone.1 = f32[1]{0} slice(f32[2]{0} %add.30.clone.1), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.39.clone.1 = f32[2]{0} pad(f32[1]{0} %slice.109.clone.1, f32[] %constant_4), padding=0_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %param_2.42 = f32[2]{0} parameter(2) %add.18.clone.1 = f32[2]{0} add(f32[2]{0} %pad.39.clone.1, f32[2]{0} %param_2.42), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %slice.49.clone.1 = f32[1]{0} slice(f32[2]{0} %add.18.clone.1), slice={[0:1]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %multiply.14.clone.1 = f32[1]{0} multiply(f32[1]{0} %slice.49.clone.1, f32[1]{0} %slice.64), metadata={op_name="jit()/jit(main)/mul" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %slice.48.clone.1 = f32[1]{0} slice(f32[4]{0} %param_1.57), slice={[2:4:2]}, metadata={op_name="jit()/jit(main)/slice[start_indices=(2,) limit_indices=(4,) strides=(2,)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.16.clone.1 = f32[1]{0} add(f32[1]{0} %multiply.14.clone.1, f32[1]{0} %slice.48.clone.1), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=8} %concatenate.4.clone.1 = f32[2]{0} concatenate(f32[1]{0} %slice.50.clone.1, f32[1]{0} %add.16.clone.1), dimensions={0}, metadata={op_name="jit()/jit(main)/concatenate[dimension=0]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.25.clone.1 = f32[4]{0} pad(f32[2]{0} %concatenate.4.clone.1, f32[] %constant_4), padding=0_1_1, 
metadata={op_name="jit()/jit(main)/pad[padding_config=((0, 1, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %pad.22.clone.1 = f32[4]{0} pad(f32[2]{0} %add.18.clone.1, f32[] %constant_4), padding=1_0_1, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} %add.15.clone.1 = f32[4]{0} add(f32[4]{0} %pad.25.clone.1, f32[4]{0} %pad.22.clone.1), metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} ROOT %tuple = (f32[4]{0}, f32[4]{0}) tuple(f32[4]{0} %add.21, f32[4]{0} %add.15.clone.1) } ENTRY %main.44 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> (f32[4], f32[4]) { %Arg_1.2 = f32[4]{0} parameter(1), sharding={replicated} %Arg_0.1 = f32[4]{0} parameter(0), sharding={replicated} %fusion.1 = f32[2]{0} fusion(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit()/jit(main)/pad[padding_config=((1, 0, 1),)]" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} ROOT %fusion.4 = (f32[4]{0}, f32[4]{0}) fusion(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2, f32[2]{0} %fusion.1), kind=kLoop, calls=%fused_computation.4, frontend_attributes={fingerprint_before_lhs="ccdc28c0ab7f4fba37664fb59b6718ef"}, metadata={op_name="jit()/jit(main)/add" source_file="/home/davis/src/learning_pallas/scan_jaxpr.py" source_line=21} } ```

If I kick the length up to 4096 there are a ton of separate fused blocks, so if you need to write to global memory to transition between those blocks, this will be reading and writing to HBM a lot.

> can pallas autodiff your version?

No since it won't know how to do a backward pass for the mamba_scan_p primitive that I added. There's also the fact that the packing isn't differentiable.

srush commented 9 months ago

Okay, fun. Let's do a version that blocks on length and saves the final h state for each block.

Also, it seems flipping a matrix is currently pretty hard to do in triton, so we might need a trick there.
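
(One possible workaround, just a sketch: instead of flipping a tile in registers, load it already reversed by indexing global memory backwards.)

```python
import triton
import triton.language as tl

@triton.jit
def reversed_load_kernel(x_ptr, y_ptr, L, L_SIZE: tl.constexpr):
    # write x reversed into y by computing reversed offsets at load time
    offs = tl.arange(0, L_SIZE)
    rev = L - 1 - offs                          # L-1, L-2, ..., masked below 0
    x = tl.load(x_ptr + rev, mask=rev >= 0, other=0.0)
    tl.store(y_ptr + offs, x, mask=offs < L)
```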

Wish I didn't have a ton of real work to do.

davisyoshida commented 9 months ago

> Let's do a version that blocks on length and saves the final h state for each block.

By the way, I did end up giving this a shot, but it didn't seem obviously possible to implement a carry as would be necessary.

srush commented 9 months ago

Oh sorry, I got distracted trying to get autodiff in triton working: https://github.com/srush/triton-autodiff. (Obviously Pallas is better, but I can't figure out how to install it.)

What do you think about a three-step approach to a carry? I was able to get this to work for a simple scan.

@triton.jit
def scan(A_bar, h, h1, h2, L_SIZE: tl.constexpr):
    # scan a block, with the carry (h1, h2) folded into position 0
    stack = pack64(A_bar, h)
    h_init = pack64(h1, h2)
    init_stack = first_order_op(h_init, stack)
    stack = tl.where(tl.arange(0, L_SIZE) == 0, init_stack, stack)
    h = tl.associative_scan(stack, -1, first_order_op)
    h1, h2 = unpack64(h)
    return h1, h2

@triton.jit
def reduce(A_bar, h):
    # collapse a block to a single (decay, state) summary
    stack = pack64(A_bar, h)
    h = tl.reduce(stack, -1, first_order_op)
    return unpack64(h)

@triton.jit
def step1_kernel(A_ptr, B_ptr, h1_ptr, h2_ptr, L_BLOCKSIZE: tl.constexpr):
    # step 1: each block reduces its chunk and stores the summary at index section + 1
    section = tl.program_id(0)
    section_start = tl.program_id(0) * L_BLOCKSIZE
    Ls = section_start + tl.arange(0, L_BLOCKSIZE)
    A, B = tl.load(A_ptr + Ls), tl.load(B_ptr + Ls)
    h1, h2 = reduce(A, B)
    tl.store(h1_ptr + section + 1, h1)
    tl.store(h2_ptr + section + 1, h2)

@triton.jit
def step2_kernel(h1_ptr, h2_ptr, L_BLOCKSIZE: tl.constexpr):
    # step 2: scan the per-block summaries so index i holds the carry entering block i
    Ls = tl.arange(0, L_BLOCKSIZE)
    h1, h2 = tl.load(h1_ptr + Ls), tl.load(h2_ptr + Ls)
    h1, h2 = scan(h1, h2, h1, h2, L_SIZE=L_BLOCKSIZE)
    tl.store(h1_ptr + Ls, h1)
    tl.store(h2_ptr + Ls, h2)

@triton.jit
def step3_kernel(A_ptr, B_ptr, h1_ptr, h2_ptr, y_ptr, L_BLOCKSIZE: tl.constexpr):
    # step 3: rescan each block, seeded with the carry computed in step 2
    section = tl.program_id(0)
    section_start = tl.program_id(0) * L_BLOCKSIZE
    Ls = section_start + tl.arange(0, L_BLOCKSIZE)
    A, B = tl.load(A_ptr + Ls), tl.load(B_ptr + Ls)
    h1 = tl.load(h1_ptr + section)
    h2 = tl.load(h2_ptr + section)
    _, y = scan(A, B, h1, h2, L_SIZE=L_BLOCKSIZE)
    tl.store(y_ptr + Ls, y)
davisyoshida commented 9 months ago

Oh yep that'll do it. My anti code duplication impulses did me a disservice here since I was trying to figure out how to get a single scan setup to work for the different cases. Nice!

srush commented 9 months ago

Oh here is a nice reference

https://mit-gfx.github.io/recfilter/supplement.pdf

[images: excerpts from the recursive filtering supplement]

srush commented 8 months ago

So I think this paper https://arxiv.org/abs/2402.19427 kind of answers this thread. It sounds like the right way to do this in Jax is simply as a linear scan. I think that means it is really just a basic scan loop in pallas.

[image: excerpt from the paper]

Kind of a bummer, I was really hoping the associative scan would work out nicely here.
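
For concreteness, a minimal sketch (my own, not from the paper) of that linear scan with jax.lax.scan:

```python
import jax
import jax.numpy as jnp

def linear_scan(a_bar, b_bar_x):
    # h_k = a_bar_k * h_{k-1} + b_bar_x_k, computed sequentially along axis 0
    def step(h, inputs):
        a, b = inputs
        h = a * h + b
        return h, h
    h0 = jnp.zeros(a_bar.shape[1:], dtype=a_bar.dtype)
    _, hs = jax.lax.scan(step, h0, (a_bar, b_bar_x))
    return hs  # all hidden states, shape (L, ...)
```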

davisyoshida commented 8 months ago

I saw that paper, but it sounded like it might have come down to the performance of memory access on TPUs. On GPUs I think we don't know what's best for this architecture yet, right?

srush commented 8 months ago

That's fair, but I found the discussion in this paper relatively convincing. These models are big enough that the compute is dwarfed by the memory loads, so the fact that associative scan has better compute is washed out.

If I find time I will implement a simple for-loop scan in triton to compare with my complex associative scan.
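
(For reference, a sketch of what that simple for-loop scan might look like in Triton; untested, and it assumes a (D, L) row-major layout and the diagonal recurrence h_t = a_t * h_{t-1} + b_t.)

```python
import triton
import triton.language as tl

@triton.jit
def for_loop_scan_kernel(a_ptr, b_ptr, y_ptr, L, D_SIZE: tl.constexpr):
    # each program owns D_SIZE channels and walks the sequence left to right
    ds = tl.program_id(0) * D_SIZE + tl.arange(0, D_SIZE)
    h = tl.zeros((D_SIZE,), dtype=tl.float32)
    for t in range(0, L):
        a = tl.load(a_ptr + ds * L + t)
        b = tl.load(b_ptr + ds * L + t)
        h = a * h + b                           # h_t = a_t * h_{t-1} + b_t
        tl.store(y_ptr + ds * L + t, h)
```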

srush commented 8 months ago

btw, many of the other issues we discussed in this thread are now fixed in Triton 3.0. You don't need packing any more for the associative scan, there is now a way to do np.roll, and you can reverse associative scan. (not that it helps yet)

davisyoshida commented 8 months ago

Didn't know about the no-packing change, that's great! I did see you put in the other changes though.

MzeroMiko commented 7 months ago

> I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.

> I'm the author of the https://github.com/johnryan465/pscan repo mentioned above and have benchmarked and tested a few different approaches, including that one. It does work, but at long sequence lengths you do start to see numerical errors creeping in (i.e., torch.allclose fails).

I found that one of my early attempts at writing mamba in pure torch is quite similar to yours. (Though I have to admit mine came later; I derived that approach independently.)

I deal with the "numerical errors" with chunks, but in practice the chunk size should be at most 64 or 32 to keep things numerically stable. As for speed, it is faster than "mamba-minimal" and "mamba.py", but quite a bit slower than "mamba".
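
(Roughly, a chunked variant might look like the sketch below; this is my own illustration of the idea, not the actual code. Within each small chunk the scan has a closed form via cumulative products, and the hidden state is carried across chunks sequentially, so the gate products never span more than `chunk` steps.)

```python
import jax
import jax.numpy as jnp

def chunked_scan(a, b, chunk=32):
    # h_t = a_t * h_{t-1} + b_t; a, b have shape (L,), with L divisible by chunk
    L = a.shape[0]
    a_c = a.reshape(L // chunk, chunk)
    b_c = b.reshape(L // chunk, chunk)

    def per_chunk(h0, ab):
        a_i, b_i = ab
        p = jnp.cumprod(a_i)                # within-chunk products of the gates
        h = p * (h0 + jnp.cumsum(b_i / p))  # closed form; only safe for small chunks
        return h[-1], h

    _, hs = jax.lax.scan(per_chunk, jnp.zeros(()), (a_c, b_c))
    return hs.reshape(L)
```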

Then I tried writing this as a torch.autograd.Function, but the speed barely changed. I also tried triton, but as I am a beginner in triton, I did not finish that.

I also have some other ideas about chunk-wise approaches, but those seem to work well theoretically, not practically.

So now I am here, searching for an approach that is more general but still has good speed compared to mamba.