I was thinking I would just do this in pytorch as well... I enjoy coding in JAX, but there just doesn't seem to be that much use outside of Google.
JAX is definitely fighting against some tough network effects with all the kernels etc coming out with torch implementations first. @radarFudan I'll probably end up porting this to JAX as part of reading it so maybe I can share it after.
Edit: Just noticed the "might take some time", so maybe I shouldn't get ahead of myself 😅
I think that mamba-minimal linked above is supposed to be quite good. I think the annotated version will likely focus on visualizations and graphs to explain the core ideas. But if you just want to port code, that other one is probably enough.
@radarFudan I did a direct port last night, which "works", but the thing it's missing is the associative scan. I'm unsure whether it's possible to implement that with lax's associative_scan primitive; I haven't thought about it much yet.
If the slow version would be enough, let me know and I can share it with you.
It's definitely possible to implement with associative scan. The question is whether it will be fast enough on GPU.
Might be interesting to eventually try Pallas and write a low level version that works both on GPU and TPU.
(you guys are going to talk me into more jax programming aren't you)
I'm pretty unsure about this since I just started looking at SSMs yesterday, but: the thing that seemed tricky about associative scan is that the body signature needs to be a -> a -> a, unlike jax.lax.scan. Since there's no separate "carry", I was thinking I'd have to have some extra unused arrays for the "output" values of A and B.
Additionally, I think it'll also need to involve a branch, since you need different behavior for each element depending on whether it is an original input or an internal node.
Is this what you were talking about when you said it might be slow on GPU?
Here's the slide I had on the associative trick for time-invariant SSMs. Might need to adjust in the general case.
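The slide itself isn't reproduced here; roughly, the trick is to scan over (a, b) pairs with a binary operator, so no separate carry is needed. A minimal JAX sketch of that idea (a reconstruction in the style of the S5 code, not the slide itself):

```python
import jax
import jax.numpy as jnp

def binary_op(left, right):
    # Composes two affine steps h -> a*h + b; this composition is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

def parallel_first_order_scan(a, b):
    # a, b: arrays of shape (..., L); returns h with h[t] = a[t]*h[t-1] + b[t], h[-1] = 0.
    _, h = jax.lax.associative_scan(binary_op, (a, b), axis=-1)
    return h

a = jnp.full((4, 8), 0.9)
b = jnp.ones((4, 8))
h = parallel_first_order_scan(a, b)
```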
Thanks that's super helpful, I'll try this out tonight.
Actually, this associative scan has already been implemented in the S5 repo: https://github.com/lindermanlab/S5/blob/main/s5/ssm.py, so it should be straightforward to adapt to the selective-scan setting.
One thing I'm not sure about is whether JAX can compile (and backpropagate through) the following recurrent system with the same complexity as the LTI system ($h_{k+1} = Ah_k + Bx_k$):
$$h_{k+1} = A(x_k) h_k + B(x_k)$$
(This is also associative so in theory it's feasible.)
I used the method from the slide Sasha shared, and I'm seeing a slight slowdown over a sequential scan. I'm guessing it's not able to fuse the multiplication with C onto the associative_scan call, but I'm not 100% sure.
Yeah, I think we're all saying the same thing (my slide is based on S5 implementation). I also found associative scan disappointingly slow when I tried it for this on GPU. Kind of a bummer, in theory this is the perfect use case. Maybe we can summon one of the JAX wizards to explain why this isn't super fast.
Oh yeah, I didn't post this yesterday, but it turns out associative_scan isn't actually compiled to a single XLA op the way scan is. When I was looking at the compiled function to see what was getting fused, I saw a ton of repeated ops, and it turns out that's because the scan is done at the JAX level. (code)
So compilation time will actually scale with the input length, and you have to hope that XLA can turn a ton of slices and concats into something fast.
Seems like a Pallas kernel will be the way to go.
Would it be possible to write a Triton version (https://github.com/jax-ml/jax-triton)? That would be useful in JAX + PyTorch (in theory anyway). It would also be compatible with Torch Compile.
Yeah, jax-triton = Pallas. Think they use triton as the GPU backend.
I tried it for a different project though and it was still really early stage.
@johnryan465 is working on a fast version of parallel scan for a bounty. He has a repo with a bunch of implementations here: https://github.com/johnryan465/pscan . No JAX version but it should still be useful
I've tried to write a JAX-version of the mamba-minimal (https://github.com/radarFudan/mamba-minimal-jax). For now the JAX version is significantly slower (5-6x) than the PyTorch version.
I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.
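For context, here is a rough PyTorch sketch of the log-space idea from the linked heinsen_sequence repo, restricted to strictly positive coefficients for simplicity (the repo itself handles the general case with complex logs):

```python
import torch

def parallel_scan_log(log_a: torch.Tensor, log_b: torch.Tensor) -> torch.Tensor:
    # Computes h[t] = a[t] * h[t-1] + b[t] (h[-1] = 0) along the last axis
    # using only cumulative ops, assuming a[t], b[t] > 0.
    a_star = torch.cumsum(log_a, dim=-1)                        # log prod_{r<=t} a[r]
    log_h = a_star + torch.logcumsumexp(log_b - a_star, dim=-1)
    return torch.exp(log_h)

a = torch.rand(4, 128) * 0.5 + 0.5
b = torch.rand(4, 128)
h = parallel_scan_log(a.log(), b.log())
```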
Have you tried training with this? I had checked it out before, but I wasn't sure if it was stable. If it works it fixes lots of things.
Not yet, but it is at least faster judging by the tweet...
But I have often used parallel scan in JAX. Can't it be used here, similarly to S5? It would also solve a lot of things if written in JAX.
As others have noted above it seems to be slow in this case. Doesn't fuse well. S5 gets away from these issues by doing MIMO.
> I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.

I'm the author of the https://github.com/johnryan465/pscan repo mentioned above and have benchmarked and tested a few different approaches, including that one. It does work, but at long sequence lengths you do start to see numerical errors creeping in (i.e., torch.allclose fails).
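As a rough illustration of the kind of check being described (not the repo's actual benchmark), one can compare a log-space parallel scan like the one sketched above against a plain sequential loop in float32:

```python
import torch

def sequential_scan(a, b):
    # Reference loop: h[t] = a[t] * h[t-1] + b[t], h[-1] = 0
    h = torch.zeros_like(b[..., 0])
    out = []
    for t in range(b.shape[-1]):
        h = a[..., t] * h + b[..., t]
        out.append(h)
    return torch.stack(out, dim=-1)

def parallel_scan_log(log_a, log_b):
    a_star = torch.cumsum(log_a, dim=-1)
    return torch.exp(a_star + torch.logcumsumexp(log_b - a_star, dim=-1))

L = 1 << 16
a = torch.rand(1, L) * 0.5 + 0.5
b = torch.rand(1, L)
ref = sequential_scan(a, b)
par = parallel_scan_log(a.log(), b.log())
# Per the comment above, expect this check to start failing once L gets large enough.
print(torch.allclose(ref, par, atol=1e-5))
```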
Yup, maybe it will not work for pre-trained models, but it can work for training new ones faster with simpler code. For educational purposes and small use cases this is better than a naive sequential scan.
I've created a repository that implements Mamba in JAX with associative scans. It isn't feature-complete yet or 100% verified to be correct, but seems quite performant. It still lags behind the reference PyTorch+CUDA implementation in speed, however, so I will explore a Pallas kernel next in this repository.
It may be of interest to those in this thread for your own implementations: https://github.com/vvvm23/mamba-jax
Just an update on the Pallas angle: it seems that Triton doesn't support a multi-tensor associative scan (openai/triton#2657). So to do it in Pallas you would probably need to implement it yourself, and I don't know how that would compare on efficiency. pl.load doesn't support strided slicing, and once you've loaded from a ref it doesn't look like you can slice at all, so I'm not sure if it's even possible at the moment. Take all that with a mound of salt since I'm just starting to look at Pallas.
Bummer. Maybe it is worth trying to merge the tensors and pay the memory overhead just to see if it works at all in Triton.
Just saw this pytorch version as well https://github.com/alxndrTL/mamba.py which seems to have good perf. @alxndrTL
(I haven't made too much annotated-mamba progress, but I will likely have a documented PyTorch pscan and a heinsen version. I'll include you guys as authors if that is okay.)
Think I got a Triton version working! It passes my tests, at least. (Based on @proger's trick here: https://github.com/proger/accelerated-scan/blob/main/accelerated_scan/triton.py)
```python
import torch
import triton
import triton.language as tl


@triton.jit
def unpack64(merged):
    tl.static_assert(merged.dtype == tl.uint64)
    b = (merged & 0xFFFFFFFF).to(tl.uint32).to(tl.float32, bitcast=True)
    a = (merged >> 32).to(tl.uint32).to(tl.float32, bitcast=True)
    return a, b


@triton.jit
def pack64(a, b):
    tl.static_assert(a.dtype == tl.float32)
    tl.static_assert(b.dtype == tl.float32)
    a = a.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    a = a << 32
    b = b.to(dtype=tl.uint32, bitcast=True).to(tl.uint64)
    return a | b


@triton.jit
def first_order_op(l, r):
    fl, xl = unpack64(l)
    fr, xr = unpack64(r)
    f = fl * fr
    x = fr * xl + xr
    return pack64(f, x)


@triton.jit
def mamba_kernel(A_ptr, B_ptr, C_ptr, d_ptr, x_ptr, y_ptr, B, L, D,
                 N: tl.constexpr, L_SIZE: tl.constexpr, D_SIZE: tl.constexpr):
    # A_ptr: D x N
    # B_ptr, C_ptr: B x L x N
    # d_ptr: B x D x L
    # x, y: B x D x L

    # Loads
    Ls, Ns, Ds = tl.arange(0, L_SIZE), tl.arange(0, N)[:, None], tl.arange(0, D_SIZE)[:, None]
    B_idx, D_idx = tl.program_id(0), tl.program_id(1) * D_SIZE
    d = tl.load(d_ptr + L * (D_idx + Ds) + D * L * B_idx + Ls)  # D x L
    A = tl.load(A_ptr + N * D_idx + Ns)  # N x 1
    B = tl.load(B_ptr + L * N * B_idx + N * Ls + Ns)  # N x L
    C = tl.load(C_ptr + L * N * B_idx + N * Ls + Ns)  # N x L
    x = tl.load(x_ptr + L * D * B_idx + L * (D_idx + Ds) + Ls)  # D x L

    # Discretize
    A = A[:, :, None]
    A_bar = tl.exp(d * A)
    B_bar = ((A_bar - A) / A) * B[:, None, :]
    h = B_bar * x  # N x D x L

    # Scan
    stack = pack64(A_bar, h)
    h = tl.associative_scan(stack, 2, first_order_op)
    _, h = unpack64(h)
    y = tl.sum(C[:, None, :] * h, 0)  # D x L

    # Write out
    tl.store(y_ptr + L * D * 2 * B_idx + L * (D_idx + Ds) + Ls, y)
```
The packing method is way smarter than the weird hacks I was trying to use to get it working with Pallas. I'll see if I can make this work with JAX as well.
Haha, yeah, I tried some crazy hacks too. This one feels pretty clean and should work nicely for backwards and f16 as well.
I'm guessing there's no way to beat Triton's associative scan, since the only ways I could accomplish strided slices and updates involved writes to global memory.
Nice work! Might give this a shot porting to Pallas and see if a similar method can work for TPU backends too
I got a Pallas kernel working using the packing trick. Unfortunately, it's still losing to jax.lax.associative_scan once the sequence length gets long enough. These numbers are from a Turing GPU; the results might be different on more recent hardware.
Could you share code for this? I can run it on a more modern GPU and see what the results are
Here's the code; it's a port of the mamba-minimal code, but with the sequence dimension at the end. Making the block size or sequence length larger leads to significantly increased compilation times, but I haven't dug into why yet.
```python
from functools import partial

import jax
from jax.core import Primitive, ShapedArray
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax._src.pallas.triton.lowering import triton_lowering_rules, _associative_scan_lowering
import triton.language as tl

# ------ Define lowering rules for casting ------
type_map = {
    jnp.float16.dtype: tl.float16,
    jnp.bfloat16.dtype: tl.bfloat16,
    jnp.float32.dtype: tl.float32,
    jnp.float64.dtype: tl.float64,
    jnp.uint8.dtype: tl.uint8,
    jnp.uint16.dtype: tl.uint16,
    jnp.uint32.dtype: tl.uint32,
    jnp.uint64.dtype: tl.uint64,
    jnp.int8.dtype: tl.int8,
    jnp.int16.dtype: tl.int16,
    jnp.int32.dtype: tl.int32,
    jnp.int64.dtype: tl.int64,
}


def bitcast_convert_type_lowering(ctx, arg, new_dtype):
    out_dtype = type_map[new_dtype]
    return arg.to(out_dtype, bitcast=True, _builder=ctx.builder)


triton_lowering_rules[jax.lax.bitcast_convert_type_p] = bitcast_convert_type_lowering


def convert_element_type_lowering(ctx, arg, new_dtype, weak_type):
    out_dtype = type_map[new_dtype]
    return arg.to(out_dtype, _builder=ctx.builder)


triton_lowering_rules[jax.lax.convert_element_type_p] = convert_element_type_lowering


# ------ Packing ------
# Compiling these individually as in Sasha's code may be better
def unpack(val):
    assert val.dtype == jnp.uint64
    mask = jnp.array(0xFFFFFFFF, dtype=jnp.uint64)
    a = (val & mask).astype(jnp.uint32)
    a = jax.lax.bitcast_convert_type(a, jnp.float32)
    b = (val >> 32).astype(jnp.uint32)
    b = jax.lax.bitcast_convert_type(b, jnp.float32)
    return a, b


def pack(a, b):
    assert a.shape == b.shape
    assert a.dtype == jnp.float32
    assert b.dtype == jnp.float32
    a = jax.lax.bitcast_convert_type(a, jnp.uint32).astype(jnp.uint64)
    b = jax.lax.bitcast_convert_type(b, jnp.uint32).astype(jnp.uint64)
    assert a.dtype == jnp.uint64
    assert b.dtype == jnp.uint64
    result = (b << 32) | a
    assert result.dtype == jnp.uint64
    return result


# ------ Scan primitive ------
mamba_scan_p = Primitive('mamba_scan')


def mamba_scan_impl(x):
    raise NotImplementedError


def mamba_scan_abstract_eval(x):
    return ShapedArray(x.shape, x.dtype)


mamba_scan_p.def_impl(mamba_scan_impl)
mamba_scan_p.def_abstract_eval(mamba_scan_abstract_eval)


def scan_op(l, r):
    a1, b1 = unpack(l)
    a2, b2 = unpack(r)
    result = (a1 * a2, a2 * b1 + b2)
    return pack(*result)


def scan_lowering(ctx, arg):
    # I'm not sure why I need to use `enable_x64` again,
    # but removing it causes a crash
    with jax.experimental.enable_x64():
        return _associative_scan_lowering(scan_op, ctx, arg, (1,))[0]


triton_lowering_rules[mamba_scan_p] = scan_lowering


# ------ Main kernel ------
def mamba_kernel(u, delta, A, B, C, D, out_ref):
    """
    u: (d_inner, seq)
    delta: (d_inner, seq)
    A: (d_inner, channels)
    B: (channels, seq)
    C: (channels, seq)
    D: (d_inner,)
    out_ref: (d_inner, seq)
    """
    delta = delta[...][:, None, :]
    u = u[...]
    deltaA = jnp.exp(delta * A[...][:, :, None])
    deltaBu = delta * B[...][None, :, :] * u[:, None, :]
    packed = pack(deltaA, deltaBu)
    result = mamba_scan_p.bind(packed)
    scan_output = unpack(result)[1]
    y = jnp.sum(scan_output * C[...][None, :, :], axis=1)
    out_ref[...] = y + u * D[...][:, None]


# ------ Interface ------
@partial(jax.jit, static_argnames=('block_size',))
def selective_scan(u, delta, A, B, C, D, block_size=8):
    d_inner, seq = u.shape
    out_struct = jax.ShapeDtypeStruct(shape=u.shape, dtype=delta.dtype)
    return pl.pallas_call(
        mamba_kernel,
        out_shape=out_struct,
        grid=(d_inner // block_size,),
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (block_size, u.shape[1])),
            pl.BlockSpec(lambda i: (i, 0), (block_size, delta.shape[1])),
            pl.BlockSpec(lambda i: (i, 0), (block_size, A.shape[1])),
            pl.BlockSpec(lambda i: (0, 0), B.shape),
            pl.BlockSpec(lambda i: (0, 0), C.shape),
            pl.BlockSpec(lambda i: (i,), (block_size,)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (block_size, u.shape[1])),
    )(u, delta, A, B, C, D)
```
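A hypothetical usage sketch for the interface above (shapes follow the kernel docstring; the names and values here are made up, and the uint64 packing means 64-bit types have to be enabled):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # required for the uint64 packing

d_inner, channels, seq = 32, 16, 1024
keys = jax.random.split(jax.random.PRNGKey(0), 5)
u = jax.random.normal(keys[0], (d_inner, seq), dtype=jnp.float32)
delta = jax.nn.softplus(jax.random.normal(keys[1], (d_inner, seq), dtype=jnp.float32))
A = -jnp.exp(jax.random.normal(keys[2], (d_inner, channels), dtype=jnp.float32))
B = jax.random.normal(keys[3], (channels, seq), dtype=jnp.float32)
C = jax.random.normal(keys[4], (channels, seq), dtype=jnp.float32)
D = jnp.ones((d_inner,), dtype=jnp.float32)

y = selective_scan(u, delta, A, B, C, D, block_size=8)
print(y.shape)  # (d_inner, seq)
```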
Very neat that it runs. Will be interesting to understand why it is slower.
I'm working on the backwards implementation, which is a bit more annoying (can Pallas autodiff your version?). Anyone know how to flip/reverse a matrix in Triton? Do I have to make a matrix to do it?
Do you think we need to block over length?
I was wondering about this as well. One possibility is to do blocks sequentially and pass the final value from the previous block into the kernel. It's possible that this is effectively what jax.lax.associative_scan ends up doing with its various fused blocks (see below).
Do you know what associative scan does?
I'm not sure what XLA will do in general, but on my system it creates a different number of fused ops depending on what I use for the sequence length. Here's some code to inspect the compiled scan:
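(The original snippet isn't reproduced here; a minimal sketch of this kind of inspection, using a packed-pair scan like the ones discussed above, could look like this:)

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

def scan_fn(a, b):
    return jax.lax.associative_scan(combine, (a, b), axis=-1)

length = 512  # try 512 vs. 4096 and compare how many fusions show up
a = jnp.ones((16, length))
b = jnp.ones((16, length))

compiled = jax.jit(scan_fn).lower(a, b).compile()
print(compiled.as_text())  # compiled HLO; count the fusion computations
```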
Here are the compiled outputs:
If I kick the length up to 4096 there are a ton of separate fused blocks, so if you need to write to global memory to transition between those blocks, this will be reading and writing to HBM a lot.
> can Pallas autodiff your version?

No, since it won't know how to do a backward pass for the mamba_scan_p primitive that I added. There's also the fact that the packing isn't differentiable.
Okay, fun. Let's do a version that blocks on length and saves the final h state for each block.
Also, it seems flipping a matrix is currently pretty hard to do in Triton, so we might need a trick there.
Wish I didn't have a ton of real work to do.
> Let's do a version that blocks on length and saves the final h state for each block.

By the way, I did end up giving this a shot, but it didn't seem obviously possible to implement a carry as would be necessary.
Oh sorry, I got distracted trying to get autodiff in Triton working: https://github.com/srush/triton-autodiff. (Obviously Pallas is better, but I can't figure out how to install it.)
What do you think about a three-step approach to a carry? I was able to get this to work for a simple scan.
```python
@triton.jit
def scan(A_bar, h, h1, h2, L_SIZE: tl.constexpr):
    stack = pack64(A_bar, h)
    h_init = pack64(h1, h2)
    init_stack = first_order_op(h_init, stack)
    stack = tl.where(tl.arange(0, L_SIZE) == 0, init_stack, stack)
    h = tl.associative_scan(stack, -1, first_order_op)
    h1, h2 = unpack64(h)
    return h1, h2


@triton.jit
def reduce(A_bar, h):
    stack = pack64(A_bar, h)
    h = tl.reduce(stack, -1, first_order_op)
    return unpack64(h)


@triton.jit
def step1_kernel(A_ptr, B_ptr, h1_ptr, h2_ptr, L_BLOCKSIZE: tl.constexpr):
    section = tl.program_id(0)
    section_start = tl.program_id(0) * L_BLOCKSIZE
    Ls = section_start + tl.arange(0, L_BLOCKSIZE)
    A, B = tl.load(A_ptr + Ls), tl.load(B_ptr + Ls)
    h1, h2 = reduce(A, B)
    tl.store(h1_ptr + section + 1, h1)
    tl.store(h2_ptr + section + 1, h2)


@triton.jit
def step2_kernel(h1_ptr, h2_ptr, L_BLOCKSIZE: tl.constexpr):
    Ls = tl.arange(0, L_BLOCKSIZE)
    h1, h2 = tl.load(h1_ptr + Ls), tl.load(h2_ptr + Ls)
    h1, h2 = scan(h1, h2, h1, h2, L_SIZE=L_BLOCKSIZE)
    tl.store(h1_ptr + Ls, h1)
    tl.store(h2_ptr + Ls, h2)


@triton.jit
def step3_kernel(A_ptr, B_ptr, h1_ptr, h2_ptr, y_ptr, L_BLOCKSIZE: tl.constexpr):
    section = tl.program_id(0)
    section_start = tl.program_id(0) * L_BLOCKSIZE
    Ls = section_start + tl.arange(0, L_BLOCKSIZE)
    A, B = tl.load(A_ptr + Ls), tl.load(B_ptr + Ls)
    h1 = tl.load(h1_ptr + section)
    h2 = tl.load(h2_ptr + section)
    _, y = scan(A, B, h1, h2, L_SIZE=L_BLOCKSIZE)
    tl.store(y_ptr + Ls, y)
```
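One way the three kernels might be wired together, as a guess at the intended driver (not from the original post): give h1/h2 one extra slot at the front holding the identity carry, let step 1 write each block's local reduction at offset section + 1, scan those carries in place with step 2, and let step 3 re-scan each block seeded with its carry.

```python
import torch

# Illustrative driver; names, sizes, and the carry layout are assumptions.
L_total, L_BLOCK = 4096, 256
num_blocks = L_total // L_BLOCK  # step 2 assumes this is a power of 2

A = torch.rand(L_total, device="cuda")
B = torch.rand(L_total, device="cuda")
y = torch.empty(L_total, device="cuda")

# One carry slot per block, plus an identity element at index 0 for the first block.
h1 = torch.ones(num_blocks + 1, device="cuda")   # multiplicative part, identity = 1
h2 = torch.zeros(num_blocks + 1, device="cuda")  # additive part, identity = 0

step1_kernel[(num_blocks,)](A, B, h1, h2, L_BLOCKSIZE=L_BLOCK)    # local reductions
step2_kernel[(1,)](h1, h2, L_BLOCKSIZE=num_blocks)                # scan the carries
step3_kernel[(num_blocks,)](A, B, h1, h2, y, L_BLOCKSIZE=L_BLOCK) # re-scan with carries
```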
Oh yep, that'll do it. My anti-code-duplication impulses did me a disservice here, since I was trying to figure out how to get a single scan setup to work for the different cases. Nice!
So I think this paper https://arxiv.org/abs/2402.19427 kind of answers this thread. It sounds like the right way to do this in JAX is simply as a linear scan. I think that means it is really just a basic scan loop in Pallas.
Kind of a bummer, I was really hoping the associative scan would work out nicely here.
I saw that paper, but it sounded like it might come down to the performance of memory access on TPUs. On GPUs I think we don't know what's best for this architecture yet, right?
That's fair, but I found the discussion in this paper relatively convincing. These models are big enough that the compute is dwarfed by the memory loads, so the fact that associative scan has better compute is washed out.
If I find time I will implement a simple for-loop scan in triton to compare with my complex associative scan.
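For reference, a minimal sketch of what such a for-loop scan could look like in Triton for the underlying recurrence (just the scalar recurrence over a (D, L) row-major layout, not the full Mamba kernel):

```python
import triton
import triton.language as tl

@triton.jit
def sequential_scan_kernel(A_ptr, B_ptr, y_ptr, L, D_BLOCK: tl.constexpr):
    # Each program owns D_BLOCK independent channels and walks the length serially:
    # h[t] = A[t] * h[t-1] + B[t], with tensors laid out as (D, L) row-major.
    rows = (tl.program_id(0) * D_BLOCK + tl.arange(0, D_BLOCK)) * L
    h = tl.zeros((D_BLOCK,), dtype=tl.float32)
    for t in range(L):
        a = tl.load(A_ptr + rows + t)
        b = tl.load(B_ptr + rows + t)
        h = a * h + b
        tl.store(y_ptr + rows + t, h)
```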
BTW, many of the other issues we discussed in this thread are now fixed in Triton 3.0. You don't need packing any more for the associative scan, there is now a way to do np.roll, and you can reverse an associative scan. (Not that it helps yet.)
Didn't know about the no-packing change, that's great! I did see you put in the other changes though.
> I think this can be very helpful: https://x.com/mkinyugo/status/1740430275375714336?s=46. Parallel scan can be implemented in a way proposed here https://github.com/glassroom/heinsen_sequence, with pure pytorch and with GPU benefits.

> I'm the author of the https://github.com/johnryan465/pscan repo mentioned above and have benchmarked and tested a few different approaches, including that one. It does work, but at long sequence lengths you do start to see numerical errors creeping in (i.e., torch.allclose fails).

I found that one of my early attempts at writing Mamba in pure PyTorch is quite similar to yours. (Though I have to admit I got there later, I derived that approach independently.)
I deal with the numerical errors by chunking, but in practice the chunk size needs to stay below 32 or 64 to keep things numerically stable. As for speed, it is faster than "mamba-minimal" and "mamba.py", but quite a bit slower than "mamba".
Then I tried writing it as a torch.autograd.Function, but the speed barely changed. I also tried Triton, but as a Triton beginner I didn't finish that.
I also have some other ideas about chunk-wise approaches, but they seem to work well in theory rather than in practice.
So now I am here, searching for an approach that is more general but still has decent speed compared to mamba.
There is a minimal version in pytorch: https://github.com/johnma2006/mamba-minimal
Looking forward to a minimal version in JAX.
Then maybe we can try to accelerate it? (I'm not an expert in CUDA; is it possible to accelerate it using mamba's original CUDA code?)