mgoin opened 1 month ago
Actually, the core idea of FlashInfer is a general block-sparse FlashAttention implementation (we will release our paper soon), which is very similar to FlexAttention, plus a runtime scheduler for load balancing/wave quantization.
I love the idea of native torch.compile support, and I'm open to exploring possible collaborations (using FlexAttention in FlashInfer, and contributing some of FlashInfer's ideas to FlexAttention). I believe we share the common goal of making LLM serving systems easier to use and faster, so feel free to loop me into the conversation.
@yzh119 I'd note that, imo, the block-sparse attention part certainly isn't new, and isn't the primary contribution of FlexAttention. FlashAttention 1 already had a block-sparse FlashAttention kernel, xFormers had one as well, the JAX folks also implemented one in SplashAttention, and if you squint your eyes, pagedattention is also basically a block-sparse attention kernel.
I think the crucial missing piece that FlexAttention provides is that, by itself, a block-sparse attention kernel basically can't implement arbitrary attention variants. Even with just a causal mask, a block-sparse kernel gets you 90% of the way there... but what about the boundary blocks that straddle the diagonal?
This is where torch.compile can help (and what FlexAttention leverages). Generating the masking function is actually fairly trivial from a codegen perspective, but it is quite difficult to do without a compiler like torch.compile in the mix.
So, my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)
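To make the boundary problem concrete, here is a minimal NumPy sketch (illustrative only, not FlexAttention's or FlashInfer's actual kernel): a block-sparse causal kernel can skip blocks that lie entirely above the diagonal, but blocks sitting on the diagonal are only partially masked, so an element-wise mask must still be applied inside those blocks. Generating that per-element mask for each attention variant is the part a compiler handles.

```python
import numpy as np

def block_sparse_causal_attention(q, k, v, block=4):
    """Toy block-sparse causal attention (seq_len must divide by `block`)."""
    n, d = q.shape
    scores = np.full((n, n), -np.inf)
    for qb in range(0, n, block):
        for kb in range(0, n, block):
            if kb > qb:
                continue  # block entirely above the causal diagonal: skip it
            s = q[qb:qb + block] @ k[kb:kb + block].T / np.sqrt(d)
            if kb == qb:
                # Diagonal block: partially masked, so we still need an
                # element-wise causal mask *inside* the block -- this is the
                # per-variant logic a compiler has to generate.
                qi = np.arange(qb, qb + block)[:, None]
                ki = np.arange(kb, kb + block)[None, :]
                s = np.where(qi >= ki, s, -np.inf)
            scores[qb:qb + block, kb:kb + block] = s
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v
```

Swapping the `qi >= ki` predicate for any other mask function is exactly the degree of freedom the FlexAttention API exposes.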
Reply to @Chillee:
> the block-sparse attention part certainly isn't new
Of course it's not new, but you should note that the flashattention repo stopped block-sparse support at FA2, and all of these repos require a large block size, while FlashInfer supports any block size; we spent a lot of effort on that.
> and if you squint your eyes, pagedattention is also basically a block-sparse attention kernel.
That's exactly how FlashInfer has implemented paged attention from the beginning.
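The "paged attention is block-sparse attention" correspondence can be sketched in a few lines of NumPy (the function name and shapes here are my own, not FlashInfer's API): reading one sequence's KV rows through its page table touches exactly the physical blocks that a block-sparse kernel would list as its nonzero column blocks.

```python
import numpy as np

def paged_gather(kv_cache, page_table, seq_len):
    """Gather one sequence's K (or V) rows from a paged cache.

    kv_cache:   (num_pages, page_size, d) array of physical pages
    page_table: logical-block -> physical-page indices for this sequence

    The pages named by `page_table` are precisely the nonzero column
    blocks a block-sparse attention kernel would visit for this sequence.
    """
    blocks = [kv_cache[p] for p in page_table]
    return np.concatenate(blocks, axis=0)[:seq_len]  # trim last partial page
```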
> my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)
I agree with this point, and I'm also working on similar things (codegen and customization); that's why I asked whether you are interested in collaborating, e.g. porting some of our any-block-size optimizations to FlexAttention, adding FlashInfer's block-sparse implementation as one of FlexAttention's backends (which I'd be happy to see), or whatever else.
🚀 The feature, motivation and pitch
FlexAttention was proposed as a performant attention implementation leveraging torch.compile, with easy APIs for adding support for complex attention variants such as Causal, Relative Positional Embeddings, Alibi, Sliding Window Attention, PrefixLM, Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention, etc.: https://pytorch.org/blog/flexattention/
While it is not the fastest attention backend (yet!), it is clearly performant enough, while enabling much more flexibility than current compiled backends, to easily implement the attention features we need for crucial models, such as the soft-capping in Gemma 2 that we currently rely on FlashInfer for. Not to mention it is a first-class citizen for torch.compile. The current blocker is that it will not be available until PyTorch 2.5.0.
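For context on the soft-capping mentioned above: Gemma 2's tanh soft-capping is just an element-wise transform of the attention logits before softmax (a minimal sketch; the cap value of 50.0 is, to my understanding, Gemma 2's attention-logit cap):

```python
import math

def soft_cap(score: float, cap: float = 50.0) -> float:
    """Tanh soft-capping: smoothly bounds a logit to (-cap, cap)
    while staying monotonic, instead of hard-clipping it."""
    return cap * math.tanh(score / cap)
```

Because it is a pure per-score transform, this is the kind of variant FlexAttention expresses as a `score_mod` (the blog post shows essentially `lambda score, b, h, q_idx, kv_idx: softcap * torch.tanh(score / softcap)`).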
Alternatives
No response
Additional context
No response