vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Support attention backend with FlexAttention #7315

Open mgoin opened 3 months ago

mgoin commented 3 months ago

🚀 The feature, motivation and pitch

FlexAttention was proposed as a performant attention implementation that leverages torch.compile and provides easy APIs for supporting complex attention variants such as Causal, Relative Positional Embeddings, ALiBi, Sliding Window Attention, PrefixLM, Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention, etc.

https://pytorch.org/blog/flexattention/

While it is not the fastest attention backend (yet!), it is clearly performant enough, and it offers far more flexibility than our current compiled backends for implementing attention features needed by crucial models, such as soft-capping in Gemma 2, which we currently rely on FlashInfer for. Not to mention it is a first-class citizen for torch.compile.
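For reference, a rough sketch of the kind of score_mod this refers to (assuming PyTorch >= 2.5 and its torch.nn.attention.flex_attention module; the cap value and shapes below are illustrative placeholders, not Gemma 2's actual configuration):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Tanh soft-capping expressed as a score_mod: applied to each attention
# score before softmax. SOFT_CAP is a placeholder value for illustration.
SOFT_CAP = 30.0

def soft_cap(score, b, h, q_idx, kv_idx):
    return SOFT_CAP * torch.tanh(score / SOFT_CAP)

# q, k, v: [batch, num_heads, seq_len, head_dim]
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Compiling fuses the score_mod into the generated attention kernel.
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, score_mod=soft_cap)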

The current blocker is that it will not be available until PyTorch 2.5.0.


Alternatives

No response

Additional context

No response

yzh119 commented 3 months ago

Actually, the core idea of FlashInfer is a general block-sparse FlashAttention implementation (we will release our paper soon), which is very similar to FlexAttention, plus a runtime scheduler for load balancing/wave quantization.

I love the idea of native torch.compile support and I'm open to exploring possible collaborations (using FlexAttention in FlashInfer, and contributing some of FlashInfer's ideas to FlexAttention). I believe we share the goal of making LLM serving systems easier to use and faster, so feel free to loop me into the conversation.

Chillee commented 3 months ago

@yzh119 I'd note that, imo, the block-sparse attention part certainly isn't new, and isn't the primary contribution of FlexAttention. FlashAttention 1 already had a block-sparse FlashAttention kernel, xFormers had one as well, the JAX folks implemented one in SplashAttention, and if you squint your eyes, PagedAttention is also basically a block-sparse attention kernel.

I think the crucial missing piece that FlexAttention provides, however, is that a block-sparse attention kernel by itself basically can't implement any attention variant completely. Even with just a causal mask, a block-sparse kernel gets you 90% of the way there... but what about the boundaries?

This is where torch.compile can help (and what FlexAttention leverages). Generating the masking function is actually fairly trivial from a codegen perspective, but is quite difficult to do without a compiler like torch.compile in the mix.

So, my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)
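As a concrete (minimal) sketch of that combination, assuming PyTorch >= 2.5's flex_attention API: the block-sparse structure lets the kernel skip fully-masked blocks, while the generated mask_mod handles the partially-masked "boundary" blocks.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 2048, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# create_block_mask precomputes which blocks are fully masked (skipped),
# fully unmasked (no per-element mask needed), or partial (the boundary
# blocks, where the codegen'd mask_mod is evaluated element-wise).
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)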

yzh119 commented 3 months ago

Reply to @Chillee :

the block-sparse attention part certainly isn't new

Of course it's not new, but you should note that the flashattention repo stopped block-sparse support at FA2, and all of these repos require a large block size, while FlashInfer supports any block size; we spent a lot of effort on that.

and if you squint your eyes, pagedattention is also basically a block-sparse attention kernel.

That's exactly how FlashInfer has implemented paged attention from the beginning.

my view is that the FlexAttention API abstraction and codegen combined is what truly unlocks the flexibility of block-sparse flashattention kernels :)

I agree with this point, and I'm also working on similar things (codegen and customization). That's why I asked whether you are interested in collaborating, e.g. porting some of our optimizations for arbitrary block sizes to your kernels, adding FlashInfer's block-sparse implementation as one of FlexAttention's backends (which I'd be happy to see), or whatever.

Tomorrowdawn commented 1 month ago

I believe using torch.compile to support attention is absolutely the right approach. I think the main issues with the FlexAttention abstraction are centered on sparse attention. A significant problem: assuming FlexAttention handles paged attention by folding the batch dimension into the sequence dimension and adding a mask to tell sequences apart (as xFormers does) - they indicated this part will be covered in a future release - that KV cache layout cannot grow. In other words, you cannot 'append' the KV cache for the next tokens to this layout. Admittedly, this can be solved by changing the mask, but

  1. The mask needs to be changed at every generation step.
  2. Memory is highly non-contiguous (because the next tokens from different sequences can only be stored interleaved).

I think flexattention needs to carefully reconsider the level of abstraction here.

Chillee commented 1 month ago

In other words, you cannot 'append' the kv cache corresponding to the next token to this layout.

This is not true. Fundamentally, from FlexAttention's perspective, there are two (independent) relevant components here.

  1. How is the data laid out in memory?
  2. How is the sparsity expressed?

The data itself can be laid out however you like. For example, for a non-paged kv-cache, an easy way to represent it would be

static_kv_cache: [Batch_size, num_heads, MAX_TOKENS, head_dim]
num_tokens: [Batch_size]
def mask_mod(b, h, q_idx, kv_idx):
    return kv_idx < num_tokens[b]

And in this case, it's trivial to "append" a new token to the kv-cache (e.g. static_kv_cache[B, cur_token] = new_kv_value). But, you can also imagine many other layouts :)
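To make that concrete, here is a minimal decode-step sketch (assuming PyTorch >= 2.5's torch.nn.attention.flex_attention module; the shapes, the decode_step helper, and the per-step mask recomputation are illustrative assumptions, not vLLM's backend):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, MAX_TOKENS, D = 4, 8, 4096, 64
k_cache = torch.zeros(B, H, MAX_TOKENS, D, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(B, H, MAX_TOKENS, D, device="cuda", dtype=torch.float16)
num_tokens = torch.zeros(B, dtype=torch.long, device="cuda")  # filled length per sequence

def mask_mod(b, h, q_idx, kv_idx):
    # Only attend to cache slots that have actually been written.
    return kv_idx < num_tokens[b]

def decode_step(q, new_k, new_v):
    # q, new_k, new_v: [B, H, 1, D] for the current token of every sequence.
    batch = torch.arange(B, device="cuda")
    # "Appending" is just an in-place write at each sequence's current length.
    k_cache[batch, :, num_tokens] = new_k[:, :, 0]
    v_cache[batch, :, num_tokens] = new_v[:, :, 0]
    num_tokens.add_(1)
    # Recomputed every step here for simplicity; see the discussion below
    # for cheaper ways to keep the block mask up to date.
    block_mask = create_block_mask(mask_mod, B=B, H=None, Q_LEN=1,
                                   KV_LEN=MAX_TOKENS, device="cuda")
    return flex_attention(q, k_cache, v_cache, block_mask=block_mask)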

The second question is - how do you update the block-sparsity mask with each new generation? Here again, there are various options.

  1. You can simply recompute your block-sparsity mask using create_block_mask. Admittedly, as you mention, this is somewhat expensive, but there are a lot of ways to ameliorate it. For example, the block-sparsity mask is not required for semantic correctness, but merely for efficiency. So, one strategy you could take is: "update your block-sparsity mask once every 128 tokens, and use score_mod to mask out the rest" (a rough sketch follows this list).
  2. You can mutate your BlockMask data structure. For example, for the static kv cache strategy, you could do something like
    block_mask.num_kv_blocks[torch.arange(NUM_BATCHES)] += (input_pos + 1) % 128 == 0 # adds one block if the current input position is about to require a new block
    # Also update `kv_indices`
  3. You can "index" into an existing BlockMask at the right position to get the BlockMask for your current query position. This is the strategy we use here: https://github.com/pytorch-labs/gpt-fast/pull/196/files#diff-801d46835a8aee4eba56246cd21be7ae8628884a17b2648ec6afc8786e319af8R76 This is both efficient (it operates only on the block-sparse data structure) and generic (it works for arbitrary masking schemes).
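As a rough illustration of option 1, reusing the names from the static-cache sketch above (and assuming all sequences advance in lockstep; the exact refresh cadence below is just one way one might apply the idea, not a prescribed recipe):

BLOCK_SIZE = 128  # default block granularity of create_block_mask

def refresh_block_mask_if_needed(block_mask, cur_pos):
    # Recompute the BlockMask only when the token just written at cur_pos
    # opens a fresh 128-token block; inside blocks already marked partial,
    # the mask_mod (kv_idx < num_tokens[b]) is evaluated at kernel runtime,
    # so newly appended positions within them should be picked up automatically.
    if block_mask is None or cur_pos % BLOCK_SIZE == 0:
        block_mask = create_block_mask(mask_mod, B=B, H=None, Q_LEN=1,
                                       KV_LEN=MAX_TOKENS, device="cuda")
    return block_mask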

Memory is highly non-contiguous (because all next tokens can only be stored in an interleaved manner).

Side note - this is not actually that bad, because you still have the head_dim dimension, which means that even if you're loading your tokens in an arbitrary interleaved manner, you still get coalesced reads. And in fact, I believe this is what FlashInfer does! (@yzh119 can correct me if I'm wrong)

Tomorrowdawn commented 1 month ago

Thank you for your explanation. It seems I need to clarify my statement and provide a specific example.

In your static example, a lot of memory would obviously be wasted because sequences have varying lengths, and PagedAttention was proposed precisely to solve this waste. The approach I mentioned above refers to using a tensor of shape [1, num_heads, total_tokens, hidden_dim], like:

kv cache: [1, num_heads, total_tokens, hidden_dim]
mask:
def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]

I have no objections regarding prefill; my main question is about the decoding phase. In this case, to 'append' to the KV cache, you can only add the entries for batch_size tokens (one from each sequence) along the total_tokens dimension, and you need to modify the corresponding mask accordingly:

def append_new(new_kv, cur_len: int):
    # new_kv: [1, num_heads, batch_size, hidden_dim], one new token per sequence
    kv_cache[:, :, cur_len : cur_len + batch_size, :] = new_kv
    document_id[cur_len + torch.arange(batch_size)] = torch.arange(batch_size)

Wouldn't this dynamic modification of the mask cause issues for torch.compile? I previously thought that the key to torch.compile's optimization was that it can know at compile time which positions don't need to be calculated. I find it hard to imagine that checking the mask for each position at runtime would be efficient. But if the mask is modified dynamically, the information that torch.compile can utilize during compilation would be greatly reduced. This is somewhat beyond my knowledge of compilation, so please correct me if I'm wrong.

Chillee commented 1 month ago

Wouldn't this dynamic modification of the mask cause issues for torch.compile?

No, this isn't an issue. A lot of the usage of FlexAttention involves masks that change on every iteration.

I previously thought that the key to torch.compile's optimization was that it can know at compile time which positions don't need to be calculated. I find it hard to imagine that checking the mask for each position at runtime would be efficient.

Hm... in this case, the key to good performance with sparsity is block sparsity, which can be checked efficiently at runtime. In some sense, you can imagine that the "bulk" of FlexAttention's performance comes from a block-sparse flashattention implementation (that doesn't require torch.compile at all - we wrote it in Triton, but it could also be written in say, CUTLASS). IMO, the main insight of FlexAttention is that if you combine this block-sparse flashattention implementation with simple codegen, you get a very flexible and performant attention implementation.

Tomorrowdawn commented 1 month ago

Excellent work beyond my imagination! I fully understand now; it seems it's time to 'talk is cheap, write some code' :D

In some sense, you can imagine that the "bulk" of FlexAttention's performance comes from a block-sparse flashattention implementation (that doesn't require torch.compile at all - we wrote it in Triton, but it could also be written in say, CUTLASS)

Looking forward to your next blog post! Perhaps you could describe the implementation of FlexAttention a bit more there, as it seems like magic from the existing blog alone. Attention is a somewhat low-level tool, and for its users, black-box implementations often lead to excessive worry (just as they did for me before XD)

I'm very interested in rewriting the vLLM backend using FlexAttention, but since vLLM involves speculative decoding and paged memory, creating a completely transparent backend might not be that easy.