patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Mamba Block #649

Open Artur-Galstyan opened 7 months ago

Artur-Galstyan commented 7 months ago

Hi there, I have just translated the Mamba layer from here to Equinox. Would you accept a PR for this?

PS: To get the most out of Mamba, we'd need to write some Pallas code akin to FlashAttention, but this wouldn't be that. FWIW, on my 3090, both implementations (MHA vs Mamba) were the same speed.

patrick-kidger commented 7 months ago

Oh, this would be awesome to have. Looking at the implementation (which looks very clean btw), I'm thinking it would make the most sense as a complete example in the documentation?

So thank you, I'd definitely like to take a PR on this.

Indeed speed-wise it'd be best to have some custom Pallas kernel. (That'd be a fun extension to this some day.) Do you know what the speed difference is between this implementation and the original?

Artur-Galstyan commented 7 months ago

Btw: I just double checked, there is a speed difference between the MHA and Mamba! But I'll need to thoroughly investigate this before making any more statements :D This also goes for the speed difference between this and the original. I'll let you know!

PS: I have no experience in writing Pallas code, so this would be a good learning opportunity.

patrick-kidger commented 7 months ago

Okay! Let me know :)

Artur-Galstyan commented 7 months ago

Okay, so at first training seemed to be 2x slower using Jax vs. the mamba-minimal implementation. Then I double checked it because I couldn't just accept that the Eqx version was so much slower.

Then I ran a different test in which I excluded the training and only tested the time it took to simply pass some input through the model.

I took the PyTorch dataloader out and simply generated 3000 batches, each with a batch size of 64 (i.e. a list [b1, b2, ..., b3000] with each b having the shape batch_size x seq_len), consisting of random numbers. No more PyTorch -> Numpy -> Jax conversion in the dataloader on the Jax side.

Now, the Jax version was really flying! Results are on an RTX 3090.

Jax version: 7.060956954956055 s for 3000 matrices at 64 x 8

Minimal Mamba version (PyTorch): 12.25408673286438 s for 3000 matrices at 64 x 8

See this for the jax version test and this for the pytorch version.

But I didn't manage to test the real Mamba version that uses the special Cuda kernels because I couldn't get it to run due to skill issues 😆

But it looks as though the Jax version should be fast enough to include it here as a first iteration.

Edit: another learning here is to not underestimate conversion times from PyTorch to Jax arrays :)

patrick-kidger commented 7 months ago

Great! Note that your benchmark scripts could still be improved a bit. For example given:

    start_time = time.time()
    mamba = eqx.filter_jit(mamba)
    for x, y in tqdm(zip(xs, ys)):
        eqx.filter_vmap(mamba)(x)

    print("", flush=True)

then you are also timing:

(a) compilation;
(b) the overhead of having vmap-jit rather than jit-vmap (always have the JIT as the final transformation!);
(c) any overhead from tqdm;
(d) flushing to stdout (generally a fairly slow thing to do).
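To make those points concrete, here is a rough sketch of a timing loop that avoids them. It is not the actual benchmark script: `toy_layer`, `benchmark`, and the array shapes are hypothetical stand-ins for the Mamba block, and it uses plain `jax.jit`/`jax.vmap`. With an Equinox module the analogous ordering would presumably be `eqx.filter_jit(eqx.filter_vmap(mamba))`.

```python
import time

import jax
import jax.numpy as jnp


def toy_layer(w, x):
    # Hypothetical stand-in for the Mamba block: maps one (seq_len, dim) sample.
    return jnp.tanh(x @ w)


def benchmark(w, batches):
    # vmap over the batch axis first, then apply JIT as the final transformation.
    forward = jax.jit(jax.vmap(toy_layer, in_axes=(None, 0)))

    # Warm-up call so that compilation happens outside the timed region.
    forward(w, batches[0]).block_until_ready()

    start = time.perf_counter()
    out = None
    for x in batches:  # no tqdm and no printing inside the timed loop
        out = forward(w, x)
    # JAX dispatches work asynchronously, so wait for the final result
    # before stopping the clock.
    out.block_until_ready()
    return time.perf_counter() - start


key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (16, 16))
batches = [
    jax.random.normal(jax.random.fold_in(key, i), (64, 8, 16)) for i in range(10)
]
elapsed = benchmark(w, batches)
print(f"{elapsed:.4f} s for {len(batches)} batches")
```

The warm-up call and the final `block_until_ready()` are the two pieces that are easiest to forget: without the first the measurement includes compilation, and without the second it may only measure how long it took to enqueue the work.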

Conversely, it also looks like the JAX version doesn't include the cost of moving arrays host->device, whilst the PyTorch version does.
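One way to put the two versions on a more equal footing would be to start from host (NumPy) arrays on the JAX side and do the transfer inside the timed loop. The following is only a sketch under that assumption: `forward` and the batch shapes are hypothetical placeholders, not the real model.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def forward(x):
    # Hypothetical stand-in for the model's forward pass.
    return jnp.tanh(x).sum()


rng = np.random.default_rng(0)
# Batches live in host memory as NumPy arrays, as they would coming out of a dataloader.
host_batches = [rng.standard_normal((64, 8), dtype=np.float32) for _ in range(10)]

# Warm up so that compilation is excluded from the timing.
forward(jax.device_put(host_batches[0])).block_until_ready()

start = time.perf_counter()
out = None
for x_host in host_batches:
    x_dev = jax.device_put(x_host)  # host -> device copy, inside the timed region
    out = forward(x_dev)
out.block_until_ready()  # wait for the asynchronously dispatched work to finish
elapsed_with_transfer = time.perf_counter() - start
print(f"{elapsed_with_transfer:.4f} s including host->device transfers")
```

Timing the same loop with arrays that were moved to the device beforehand would then give a rough idea of how much of the total is transfer cost.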

A comparison to the original CUDA kernels would definitely be particularly interesting, but I don't know how tricky that would be to set up.

patrick-kidger commented 7 months ago

Also relevant: https://github.com/srush/annotated-mamba/issues/1