Actually, the core idea of FlashInfer is a general block-sparse FlashAttention implementation (we will release our paper soon), which is very similar to FlexAttention, plus a runtime scheduler for load balancing/wave quantization.
I love the idea of native support for `torch.compile`, and I'm open to exploring possible collaborations (using FlexAttention in FlashInfer and contributing some of FlashInfer's ideas to FlexAttention). I believe we share the goal of making LLM serving systems easier to use and faster, so feel free to loop me in on the conversation.
@yzh119 I'd note that, imo, the block-sparse attention part certainly isn't new, and isn't the primary contribution of FlexAttention. FlashAttention 1 already had a block-sparse FlashAttention kernel, xformers had one as well, the JAX folks also implemented one in SplashAttention, and if you squint your eyes, PagedAttention is also basically a block-sparse attention kernel.
I think the crucial missing piece that FlexAttention provides, however, is that by itself, you basically can't implement any attention variants with a block-sparse attention kernel. Even with just a causal mask, you can get 90% of the way there with a block-sparse attention kernel... but what about the boundaries?
This is where `torch.compile` can help (and what FlexAttention leverages). Generating the masking function is actually fairly trivial from a codegen perspective, but is quite difficult to do without a compiler like `torch.compile` in the mix.
So, my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)
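For illustration, a minimal sketch of this through the FlexAttention API (assuming PyTorch >= 2.5; the shapes and names are illustrative):

```python
# Minimal sketch, assuming PyTorch >= 2.5 (torch.nn.attention.flex_attention).
# The user only writes the mask_mod; the block-sparse kernel skips fully-masked
# tiles, and codegen applies the mask inside the partial "boundary" tiles.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 4096, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Precompute which (q_block, kv_block) tiles need to be computed at all.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
```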
Reply to @Chillee:
> the block-sparse attention part certainly isn't new
Of course it's not new, but you should note that the flashattention repo stopped supporting block sparsity as of FA2, and all of these repos require a large block size, while FlashInfer supports any block size; we spent a lot of effort on that.
> and if you squint your eyes, PagedAttention is also basically a block-sparse attention kernel.
That's exactly how FlashInfer has implemented paged attention from the beginning.
> my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)
I agree with this point, and I'm also working on similar things (codegen and customization); that's why I asked whether you are interested in collaborating, e.g. porting some of our any-block-size optimizations to your kernels, adding FlashInfer's block-sparse implementation as one of the FlexAttention backends (which I'd be happy to see), or whatever.
I believe using `torch.compile` to support attention is absolutely the right approach. I think the main issues with the `flexattention` abstraction are centered on sparse attention. A significant problem: assuming that `flexattention`'s approach to handling paged attention is to compress the batch-size dimension onto the sequence dimension and add a mask to differentiate documents (like `xformers` does; I'm assuming this because they indicated that part will be released later), this kv cache layout cannot grow. In other words, you cannot 'append' the kv cache corresponding to the next token to this layout. Admittedly, this can be solved by changing the mask, but I think `flexattention` needs to carefully reconsider the level of abstraction here.
> In other words, you cannot 'append' the kv cache corresponding to the next token to this layout.
This is not true. Fundamentally, from FlexAttention's perspective, there are two (independent) relevant components here.
The data itself can be laid out however you like. For example, for a non-paged kv cache, an easy way to represent it would be

```
static_kv_cache: [Batch_size, num_heads, MAX_TOKENS, head_dim]
num_tokens: [Batch_size]

def mask_mod(b, h, q_idx, kv_idx):
    return kv_idx < num_tokens[b]
```
And in this case, it's trivial to "append" a new token to the kv cache (e.g. `static_kv_cache[B, cur_token] = new_kv_value`). But you can also imagine many other layouts :)
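To make this concrete, here is a minimal sketch of the static layout above; the helper names and shapes are illustrative assumptions:

```python
# Illustrative sketch of the static (non-paged) kv-cache layout above: appending
# a token is a plain index write, and the mask_mod reads the per-batch length at
# runtime. Names (static_k, static_v, append) are assumptions for illustration.
import torch

B, H, MAX_TOKENS, D = 4, 8, 8192, 64
static_k = torch.zeros(B, H, MAX_TOKENS, D, device="cuda", dtype=torch.float16)
static_v = torch.zeros_like(static_k)
num_tokens = torch.zeros(B, dtype=torch.int64, device="cuda")

def append(b, new_k, new_v):
    # new_k / new_v: [H, D] for a single new token of batch element b.
    t = num_tokens[b]
    static_k[b, :, t] = new_k
    static_v[b, :, t] = new_v
    num_tokens[b] += 1

def mask_mod(b, h, q_idx, kv_idx):
    # Only attend to positions that have actually been written.
    return kv_idx < num_tokens[b]
```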
The second question is: how do you update the block-sparsity mask with each new generation? Here again, there are various options.

1. You can simply recompute it with `create_block_mask`. Admittedly, like you mention, this is somewhat expensive, but there are a lot of ways to ameliorate this. For example, the block-sparsity mask is not required for semantic correctness, but merely for efficiency. So, one strategy you could take is: "update your block-sparsity mask once every 128 tokens, and use `score_mod` to mask out the rest" (see the sketch after this list).
2. You can directly modify the `BlockMask` data structure. For example, for the static kv cache strategy, you could do something like
```python
block_mask.kv_num_blocks[torch.arange(NUM_BATCHES)] += (input_pos + 1) % 128 == 0  # adds one block if the current input position is about to require a new block
# Also update `kv_indices`
```
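A minimal sketch of the first strategy (refresh every 128 tokens), with illustrative names (`BLOCK`, `valid_token`, `maybe_refresh`) and under the assumption that the block mask only affects efficiency, not correctness:

```python
# Sketch of "refresh the BlockMask only once every 128 tokens". Between
# refreshes the block mask is a superset of what is needed (rounded up to the
# block size) and the mask_mod masks not-yet-written positions, so a slightly
# stale mask only costs efficiency, never correctness. All names here
# (BLOCK, valid_token, maybe_refresh, MAX_TOKENS) are illustrative.
import torch
from torch.nn.attention.flex_attention import create_block_mask

BATCH, MAX_TOKENS, BLOCK = 4, 8192, 128
num_tokens = torch.zeros(BATCH, dtype=torch.int64, device="cuda")

def valid_token(b, h, q_idx, kv_idx):
    return kv_idx < num_tokens[b]

def maybe_refresh(block_mask, step, q_len):
    # Rebuild the block-sparsity metadata only when a new 128-token block starts.
    if step % BLOCK == 0:
        block_mask = create_block_mask(valid_token, B=BATCH, H=None,
                                       Q_LEN=q_len, KV_LEN=MAX_TOKENS)
    return block_mask
```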
> Memory is highly non-contiguous (because all next tokens can only be stored in an interleaved manner).
Side note - this is not actually that bad, because you have the `head_dim`, which means that even if you're loading your tokens in an arbitrary interleaved manner, you still get coalesced reads. And in fact, I believe this is what FlashInfer does! (@yzh119 can correct me if I'm wrong)
Thank you for your explanation. It seems I need to clarify my statement and provide a specific example.
In your static example, obviously, due to sequences of varying lengths, a lot of memory would be wasted, and PagedAttention was proposed precisely to address this waste. The approach I mentioned above refers to using a tensor of shape `[1, num_heads, total_tokens, hidden_dim]`, like:

```
kv cache: [1, num_heads, total_tokens, hidden_dim]

mask:
def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]
```
I have no objections regarding prefill; my main question is about the decoding phase. In this case, to 'append' the kv cache, you can only add the kv cache corresponding to `[batch_size]` tokens (one from each sequence) along the `total_tokens` dimension, and you need to modify the corresponding mask for this:

```python
# pseudocode: append one new token per sequence along the total_tokens dimension
def append_new(new_kv, cur_len: int):  # new_kv: [1, ..., batch_size, ...]
    kv_cache[:, :, cur_len: cur_len + batch_size, :] = new_kv
    document_id[cur_len + torch.arange(batch_size)] = torch.arange(batch_size)
```
Wouldn't this dynamic modification of the mask cause issues for `torch.compile`? I previously thought that the key to `torch.compile`'s optimization was that it can know at compile time which positions don't need to be computed; I find it hard to imagine that checking the mask for each position at runtime would be efficient. But if the mask is dynamically modified, the information that `torch.compile` can utilize during compilation would be greatly reduced. This is somewhat beyond my knowledge of compilers, so if what I'm saying is incorrect, please correct me.
> Wouldn't this dynamic modification of the mask cause issues for `torch.compile`?
No, this isn't an issue. A lot of the usage of FlexAttention involves masks that change on every iteration.
> I previously thought that the key to `torch.compile`'s optimization was that it can know at compile time which positions don't need to be computed. I find it hard to imagine that checking the mask for each position at runtime would be efficient.
Hm... in this case, the key to good performance with sparsity is block sparsity, which can be checked efficiently at runtime. In some sense, you can imagine that the "bulk" of FlexAttention's performance comes from a block-sparse flashattention implementation (that doesn't require `torch.compile` at all - we wrote it in Triton, but it could also be written in, say, CUTLASS). IMO, the main insight of FlexAttention is that if you combine this block-sparse flashattention implementation with simple codegen, you can get a very flexible and performant attention implementation.
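A minimal sketch of why a changing mask is not a recompile problem, assuming PyTorch >= 2.5: the block-sparsity information is plain tensor metadata on the `BlockMask` object, which the kernel reads at runtime.

```python
# Sketch, assuming PyTorch >= 2.5: the BlockMask is plain tensor metadata
# (which kv blocks each query block needs), so a changing mask means new data,
# not new code or a recompile.
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
print(bm.kv_num_blocks.shape)  # number of kv blocks per (batch, head, q_block)
print(bm.kv_indices.shape)     # which kv blocks those are
print(bm.sparsity())           # percentage of blocks the kernel can skip entirely
```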
Excellent work beyond my imagination! I fully understand now; it seems it's time to 'talk is cheap, write some code' :D
> In some sense, you can imagine that the "bulk" of FlexAttention's performance comes from a block-sparse flashattention implementation (that doesn't require `torch.compile` at all - we wrote it in Triton, but it could also be written in, say, CUTLASS)
Looking forward to your next blog post! Perhaps you could describe the implementation of FlexAttention a bit there, as it seems like magic from the existing blog alone. Attention is a somewhat low-level tool, and for its users, relying on black-box implementations often leads to excessive worry (just like I did before XD)
I'm very interested in rewriting the vLLM backend using FlexAttention, but since vLLM involves speculative decoding and paged memory, creating a completely transparent backend might not be that easy to accomplish.
🚀 The feature, motivation and pitch
FlexAttention was proposed as a performant attention implementation leveraging `torch.compile`, with easy APIs for adding support for complex attention variants such as Causal, Relative Positional Embeddings, Alibi, Sliding Window Attention, PrefixLM, Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention, etc.: https://pytorch.org/blog/flexattention/

While it is not the fastest attention backend (yet!), it is clearly performant enough, while enabling much more flexibility than current compiled backends to easily implement attention features we need for crucial models, such as soft-capping in Gemma 2, which we currently rely on FlashInfer for. Not to mention it is a first-class citizen for `torch.compile`.

The current blocker is that it will not be available until PyTorch 2.5.0.
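As a reference point, the Gemma 2 soft-capping mentioned above maps onto a one-line `score_mod`; a minimal sketch assuming PyTorch >= 2.5, with an illustrative cap value:

```python
# Minimal sketch of tanh soft-capping as a FlexAttention score_mod, assuming
# PyTorch >= 2.5; the cap value is illustrative, not Gemma 2's exact config.
import torch
from torch.nn.attention.flex_attention import flex_attention

softcap = 50.0

def soft_cap(score, b, h, q_idx, kv_idx):
    return softcap * torch.tanh(score / softcap)

B, H, S, D = 1, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
out = torch.compile(flex_attention)(q, k, v, score_mod=soft_cap)
```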