vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Support attention backend with FlexAttention #7315

Open mgoin opened 1 month ago

mgoin commented 1 month ago

🚀 The feature, motivation and pitch

FlexAttention was proposed as a performant attention implementation built on torch.compile, with easy APIs for supporting complex attention variants such as Causal masking, Relative Positional Embeddings, ALiBi, Sliding Window Attention, PrefixLM, Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention, etc.

https://pytorch.org/blog/flexattention/

While it is not the fastest attention backend (yet!), it is clearly performant enough, and it offers much more flexibility than our current compiled backends for implementing attention features we need for crucial models, such as the soft-capping in Gemma 2 that we currently rely on FlashInfer for. Not to mention it is a first-class citizen for torch.compile.

The current blocker is that FlexAttention will not be available until PyTorch 2.5.0.
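For reference, here is a minimal sketch (assuming the torch.nn.attention.flex_attention API that lands in PyTorch 2.5) of how a Gemma-2-style tanh soft-cap could be expressed as a score_mod; the shapes and cap value are illustrative, not vLLM's actual integration:

```python
# Minimal sketch of tanh soft-capping via FlexAttention's score_mod hook.
# Assumes PyTorch >= 2.5; shapes and the cap value are illustrative only.
import torch
from torch.nn.attention.flex_attention import flex_attention

SOFT_CAP = 50.0  # illustrative cap, not necessarily Gemma 2's exact configuration

def tanh_soft_cap(score, b, h, q_idx, kv_idx):
    # Squash the raw attention logit into (-SOFT_CAP, SOFT_CAP) before softmax.
    return SOFT_CAP * torch.tanh(score / SOFT_CAP)

B, H, S, D = 1, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
           for _ in range(3))

# torch.compile fuses the score_mod into the generated attention kernel.
compiled_attn = torch.compile(flex_attention)
out = compiled_attn(q, k, v, score_mod=tanh_soft_cap)
```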


Alternatives

No response

Additional context

No response

yzh119 commented 1 month ago

Actually, the core idea of FlashInfer is a general block-sparse FlashAttention implementation (we will release our paper soon), which is very similar to FlexAttention, plus a runtime scheduler for load balancing/wave quantization.

I love the idea of native torch.compile support, and I'm open to possible collaborations (using FlexAttention in FlashInfer and contributing some of FlashInfer's ideas to FlexAttention). I believe we share the common goal of making LLM serving systems easier to use and faster, so feel free to loop me into the conversation.

Chillee commented 1 month ago

@yzh119 I'd note that, imo, the block-sparse attention part certainly isn't new and isn't the primary contribution of FlexAttention. FlashAttention 1 already had a block-sparse FlashAttention kernel, xformers had one as well, the JAX folks also implemented one in SplashAttention, and if you squint your eyes, PagedAttention is also basically a block-sparse attention kernel.

I think the crucial missing piece that FlexAttention provides, however, is that with a block-sparse attention kernel by itself you basically can't implement any attention variants. Even with just a causal mask, a block-sparse kernel gets you 90% of the way there... but what about the boundary blocks, where the mask only partially applies?

This is where torch.compile can help (and what FlexAttention leverages). Generating the masking function is actually fairly trivial from a codegen perspective, but is quite difficult to do without a compiler like torch.compile in the mix.
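To make that division of labor concrete, here is a sketch (again assuming the torch.nn.attention.flex_attention API): create_block_mask lets the kernel skip fully-masked blocks entirely, while the compiled mask_mod is applied elementwise only inside the partially-masked blocks on the causal boundary.

```python
# Sketch: block sparsity handles fully-masked blocks, codegen handles boundaries.
# Assumes PyTorch >= 2.5; sizes are illustrative.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # True where attention is allowed. Blocks that are entirely False are
    # skipped outright; blocks straddling the diagonal apply this per element.
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 2048, 64
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)

q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
           for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```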

So, my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)

yzh119 commented 1 month ago

Reply to @Chillee :

the block-sparse attention part certainly isn't new

Of course it's not new, but note that the flash-attention repo stopped block-sparse support at FA2, and all of these repos require a large block size, while FlashInfer supports any block size; we spent a lot of effort on that.

and if you squint your eyes, pagedattention is also basically a block-sparse attention kernel.

That's exactly how FlashInfer has implemented PagedAttention from the beginning.
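To spell out the equivalence (a hypothetical layout for illustration, not FlashInfer's actual data structures): a request's page table is just the column-index structure of a block-sparse attention matrix whose block size equals the page size.

```python
# Hypothetical illustration of PagedAttention viewed as block sparsity
# (not FlashInfer's actual internals): a request's page table lists exactly
# the nonzero column blocks that its queries may attend to.
import torch

PAGE_SIZE = 16        # tokens per KV-cache page == sparse block size
NUM_CACHE_PAGES = 64  # pages in the shared, paged KV cache

# Non-contiguous pages holding this request's KV history:
page_table = torch.tensor([3, 17, 42, 5])

# BSR-style indices for a single decode query "block row":
indptr = torch.tensor([0, page_table.numel()])  # one row of blocks
indices = page_table                            # its nonzero block columns

print(f"decode step attends to {indices.numel()} of {NUM_CACHE_PAGES} "
      f"KV blocks ({PAGE_SIZE} tokens each); all other blocks are skipped")
```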

my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)

I agree with this point, and I'm also working on similar things (codegen and customization). That's why I asked whether you'd be interested in collaborating, e.g. porting some of our any-block-size optimizations to your kernels, adding FlashInfer's block-sparse implementation as one of the FlexAttention backends (which I'd be happy to see), or whatever.