lucidrains / flash-attention-jax

Implementation of Flash Attention in Jax
MIT License

Performance benchmarks? #9

Open imoneoi opened 1 year ago

imoneoi commented 1 year ago

Are there any benchmark results yet? Looking forward to performance comparisons with vanilla attention and with the official torch+CUDA implementation.

jakubMitura14 commented 1 year ago

I am also curious. Additionally, maybe it is possible to use CUDA code with JAX?

https://github.com/dfm/extending-jax

OhadRubin commented 1 year ago

https://colab.research.google.com/drive/1-YCU9ps4gNuROJ3_8MLjSpbICGHaySxh?usp=sharing

jakubMitura14 commented 1 year ago

Fantastic! Have you run the same experiment, with the same data, on the original flash attention?

OhadRubin commented 1 year ago

Not yet

jon-chuang commented 1 year ago

Hello, could I ask if this works with TPUs?

evanatyourservice commented 10 months ago

Here's an updated notebook for anyone interested; it precompiles the jitted functions and blocks on results until they are ready:

https://colab.research.google.com/drive/11QKRdgMtcivrJNmjTrf2bXTE5yXkXl_Z?usp=sharing

It looks like JAX compiles vanilla attention in a way that ends up faster than this JAX flash attention implementation, so there's no need to switch to flash attention if you use JAX.
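For anyone who wants to reproduce the comparison locally, here is a minimal sketch of the timing pattern the notebook uses (compile first, then block on the device result before stopping the clock). Shapes and iteration counts are made up; to time this repo's kernel you would call its `flash_attention` the same way (check the README for the exact signature).

```python
import time
import jax
import jax.numpy as jnp
from jax import random

def vanilla_attention(q, k, v):
    # reference O(n^2) attention
    scale = q.shape[-1] ** -0.5
    sim = jnp.einsum('...id,...jd->...ij', q, k) * scale
    attn = jax.nn.softmax(sim, axis=-1)
    return jnp.einsum('...ij,...jd->...id', attn, v)

key = random.PRNGKey(0)
# (batch, heads, seq_len, head_dim) -- assumed layout for this sketch
q = k = v = random.normal(key, (1, 8, 2048, 64))

fn = jax.jit(vanilla_attention)
fn(q, k, v).block_until_ready()  # warm-up call triggers compilation

start = time.perf_counter()
for _ in range(10):
    out = fn(q, k, v)
out.block_until_ready()          # JAX dispatch is async; wait before reading the clock
print('avg seconds per call:', (time.perf_counter() - start) / 10)
```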

SamuelGabriel commented 10 months ago

Wow this is open from almost a year ago...

I think someone could get a lot of citations/clicks if they did a proper benchmark of transformer training/inference across platforms (torch/jax, GPU/TPU) with standard tricks. I would cite you straight away, for sure. It would also be nice to settle the dispute that Google employees seem to have with everyone else about whether or not JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

niemiaszek commented 9 months ago

> Wow this is open from almost a year ago...
>
> I think someone could get a lot of citations/clicks if they did a proper benchmark of transformer training/inference across platforms (torch/jax, GPU/TPU) with standard tricks. I would cite you straight away, for sure. It would also be nice to settle the dispute that Google employees seem to have with everyone else about whether or not JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

It would definitely be nice to see such a benchmark, but I can imagine how hard it is to compare JAX vs. PyTorch (GPU/TPU), with many optimized implementations for each device. For PyTorch on GPU we have Triton/CUDA, but JAX has also recently added a Triton-like mechanism for writing custom kernels for GPU/TPU: Pallas. You can even find an implementation of attention in it here.
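For a taste of what Pallas looks like, here is a toy element-wise kernel (not attention; the attention kernel mentioned above lives in JAX's experimental Pallas ops, and its import path may change between releases):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # refs give block-level access to memory, much like Triton's pointer blocks
    o_ref[...] = x_ref[...] + 1.0

@jax.jit
def add_one(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode so the sketch also runs without a GPU/TPU
    )(x)

print(add_one(jnp.zeros((8, 128), jnp.float32)))
```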

evanatyourservice commented 9 months ago

@niemiaszek I just recently saw they named and added docs for Pallas; it looks very interesting. The JAX team is also improving our ability to customize how networks are sharded across accelerators and is publishing papers on their efficiency results, which is pretty cool I think. Unfortunately I don't have time to do a fair comparison between torch and jax for attention, but it seems that whoever takes the time to delve into it, especially JAX's recent improvements, would certainly benefit if they have the need.

Even if we don't take the time, it looks like the JAX team continually folds their efficiency findings into JAX as defaults, so we don't have to implement them ourselves.

lucidrains commented 9 months ago

from what i've heard, flash attention doesn't work well on TPUs, but i haven't kept up with the latest iteration of their chip design.

Pallas is just a wrapper around Triton, which was developed at OpenAI for GPUs. you will basically always be limited by what the Triton compiler can do

lucidrains commented 9 months ago

while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for finely controlled GPGPU programming.

evanatyourservice commented 9 months ago

@lucidrains I'd agree as far as single-device optimizations go. I solely use jax because my work deals mainly with RL and I've already built everything out, but for things like language and vision models, resources like xformers are hard to beat. I do like jax's work toward multi-device customization especially from an RL perspective.

jon-chuang commented 9 months ago

> while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for finely controlled GPGPU programming.

Well, I would argue that these days it's no longer such a hard pill, given the wide adoption of tiled programming paradigms like Triton (e.g. PyTorch, with both codegen and incoming custom kernels; JAX, with Pallas; and hardware vendors including NVIDIA, AMD, and Intel), which greatly reduces the effort and complexity of getting SOTA performance on GPUs.

lucidrains commented 9 months ago

@jon-chuang hmm, still a bit early to declare that imho

we'll see, i hope so!

jon-chuang commented 9 months ago

Yes, Triton is still not 100% there (some matmul kernel sizes and certain kernels like the flash attention backward pass are still not SOTA). But it's certainly the direction the industry is investing in, and IMO it's good news for developers and tinkerers who want hackability at each layer of the stack.

I've already heard of some success stories with customizing flash attention kernels via Triton.

jon-chuang commented 9 months ago

I think these newish attention replacements will take time to be adopted, particularly because the dust has not settled on them, and it takes a while for wide-scale experimentation and large-scale training to truly prove them out.

IMO all it takes is a leap for a highly-funded industrial lab to go out on a limb and train an LLM with one of these...

For instance, Mistral AI essentially has a linear-cost attention mechanism based on SWA (sliding window attention); one could of course argue how effective it is at truly capturing information across a long context.
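To make the SWA idea concrete, here is a naive sketch of a banded causal mask in pure JAX; note that a real implementation tiles the computation so the full n x n matrix is never materialized, which is where the linear cost comes from. Window size and shapes here are arbitrary.

```python
import jax
import jax.numpy as jnp

def sliding_window_mask(seq_len, window):
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    # position i attends only to the previous `window` positions, causally
    return (j <= i) & (j > i - window)

def sliding_window_attention(q, k, v, window=256):
    scale = q.shape[-1] ** -0.5
    sim = jnp.einsum('...id,...jd->...ij', q, k) * scale
    mask = sliding_window_mask(q.shape[-2], window)
    sim = jnp.where(mask, sim, -jnp.inf)
    attn = jax.nn.softmax(sim, axis=-1)
    return jnp.einsum('...ij,...jd->...id', attn, v)
```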

> …all these frameworks cannot do.

I think this is an overstatement? I think it simply has not been tried in Triton yet. It should not be that hard, but whether the performance matches is an open question.

I just hope that more devs become aware of how powerful Triton is, so that there's more experimentation with implementing these kernels.

lucidrains commented 9 months ago

@jon-chuang yea, let us just agree that we both wish for Triton and the like to succeed so that we non-CUDA experts can have control over the entire stack

i just know it isn't there yet.

jon-chuang commented 9 months ago

Interestingly, a basic building block for Mamba (associative scan) already has support in Triton: https://github.com/pytorch/pytorch/issues/95408#issuecomment-1653748896
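For reference, the same building block is a one-liner in plain JAX. This just illustrates why an associative scan is the core of linear recurrences like Mamba's; it says nothing about the Triton kernel's maturity, and the coefficients below are made up.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # composing two affine maps h -> a*h + b is associative,
    # which is what lets the recurrence run as a parallel scan
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

a = jnp.full((8,), 0.9)               # decay coefficients
b = jnp.arange(8, dtype=jnp.float32)  # inputs

_, h = jax.lax.associative_scan(combine, (a, b))
print(h)  # h[t] = a[t] * h[t-1] + b[t], with h[-1] = 0
```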

lucidrains commented 9 months ago

it doesn't support multiple inputs. also i heard it is still buggy in its current state

lucidrains commented 9 months ago

@jon-chuang anyways, let us take the discussion elsewhere, as this is about flash attention

MasterSkepticista commented 1 month ago

Flash attention is now available in jax-nightly with a cudnn implementation: jax.nn.dot_product_attention. It only supports Ampere architecture and later.

Note that the default implementation is xla.
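A quick sketch of calling it, assuming the documented (batch, seq, heads, head_dim) layout; as far as I can tell the cudnn path also expects fp16/bf16 inputs:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
B, T, N, H = 2, 1024, 8, 64
q = random.normal(key, (B, T, N, H), dtype=jnp.bfloat16)
k = random.normal(key, (B, T, N, H), dtype=jnp.bfloat16)
v = random.normal(key, (B, T, N, H), dtype=jnp.bfloat16)

# implementation=None lets the backend pick; 'cudnn' requests the fused
# flash-attention kernel (Ampere or newer), 'xla' forces the default lowering.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation="cudnn")
```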