Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Support for Left Padding Mask KV? #649

Open aciddelgado opened 8 months ago

aciddelgado commented 8 months ago

Are there plans for, or an existing way to support, a left-padding attention mask for the KV cache? I believe right padding can already be supported with the mha_fwd_kvcache API via the seqlens_k pointer, but will there be a similar option for left padding?
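For context, a minimal sketch of how right padding works today through the Python wrapper (assuming flash_attn_with_kvcache, which as I understand it wraps mha_fwd_kvcache; exact argument names may differ between versions):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, head_dim = 2, 8, 64
max_cache_len = 128

q = torch.randn(batch, 1, heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(batch, max_cache_len, heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# With right padding, each sequence occupies slots [0, seqlen_k) of its cache row,
# so a per-batch length is enough to mask out the unused tail.
cache_seqlens = torch.tensor([5, 17], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens)

# With left padding the valid region would instead be
# [max_cache_len - seqlen_k, max_cache_len), which a length-only argument cannot express.
```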

tridao commented 8 months ago

I haven't used left padding. What's the use case for left padding instead of right padding the KV cache?

aciddelgado commented 8 months ago

@tridao We are trying to support as many different formats as possible. In our case, a lot of models are trained with left padding, and it's useful to support it directly at the kernel level. Are there plans for left-padding support, or for general masking? Thank you!

turboderp commented 6 months ago

I'm struggling with this as well. Consider:

seq     0      1      2      3      4      5
 0    Hello
 1    Once   upon    a     time     ,

Here we can't really do any batching at all because the sequences don't line up. We could produce four tokens for seq 0 first, then begin the batched inference after that, or we could start batching right away but discard the results for seq 1 until we reach position 5. Either approach is wasteful compared to left-padding:

seq     0      1      2      3      4      5
 0    (pad)  (pad)  (pad)  (pad)  Hello
 1    Once   upon    a     time     ,

Now you can sample token 5 for both seqs in one forward pass immediately. The tradeoffs are the wasted inference on the padding tokens during prompt ingestion, and the extra VRAM allocated to the masked keys/values. Whether those are good tradeoffs depends on the circumstances.
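To make the bookkeeping concrete, here's roughly how such a left-padded batch gets built (plain PyTorch, independent of any particular attention kernel; the token ids and pad_id below are placeholders, not from any real tokenizer):

```python
import torch

pad_id = 0  # placeholder pad token id

# Placeholder token ids: seq 0 = "Hello", seq 1 = "Once upon a time ,"
seqs = [
    [42],
    [11, 12, 13, 14, 15],
]

max_len = max(len(s) for s in seqs)
input_ids = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
for i, s in enumerate(seqs):
    input_ids[i, max_len - len(s):] = torch.tensor(s)   # tokens pushed to the right
    attention_mask[i, max_len - len(s):] = True         # left pad slots stay masked

# Positions must start counting at the first real token, not at the pad slots.
position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0)

# Every sequence's last real token now sits in the final column, so a single
# forward pass yields next-token logits for the whole batch.
```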

One thing I'm working on at the moment is classifier-free guidance, where two prompts of roughly equal length (but maybe differing sentiment) are evaluated in parallel to sample one token from a mix of the two sets of logits. Right-padding simply doesn't work for that. Unpadding could work, but it's also wasteful since it requires reshaping the entire K/V cache once per token.
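The logit mixing itself is trivial; it's only the batching that's awkward. Roughly (guidance_scale here is just an illustrative hyperparameter):

```python
import torch

def cfg_mix(cond_logits: torch.Tensor,
            uncond_logits: torch.Tensor,
            guidance_scale: float = 1.5) -> torch.Tensor:
    """Classifier-free guidance: push the conditional distribution away from
    the unconditional one. Both logit tensors come from the same forward pass
    over a batch of two left-padded sequences."""
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)
```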

If there were some way to modulate the attention weights before the softmax, that would unlock not just left-padding but some neat opportunities in speculative decoding as well.
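By "modulate" I mean something like the naive reference form below, where an additive bias on the scores before the softmax can express a left-padding mask (-inf entries) or an arbitrary reweighting. This is plain PyTorch for illustration, obviously not how a fused kernel would implement it:

```python
import math
import torch

def attention_with_bias(q, k, v, bias):
    """q, k, v: (batch, heads, seqlen, head_dim); bias: broadcastable to
    (batch, heads, seqlen_q, seqlen_k). A -inf entry in `bias` masks a key
    out entirely (e.g. left padding); finite values just reweight it."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    probs = torch.softmax(scores + bias, dim=-1)
    return probs @ v
```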