Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Any plans to support tree attention mask? #924

Open KexinFeng opened 2 months ago

KexinFeng commented 2 months ago

Tree attention masks are already supported in huggingface/transformers (https://github.com/huggingface/transformers/pull/27539), and they would be very helpful for speculative decoding applications. More specifically, the tree attention mask would need to be passed as an argument to flash_attn/flash_attn_interface.py#flash_attn_with_kvcache.

Do you have any plans to support it in the near future?

Thanks

Related questions: https://github.com/Dao-AILab/flash-attention/issues/840, https://github.com/Dao-AILab/flash-attention/issues/918
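As a concrete picture of the mask being requested, here is a minimal sketch that builds a tree attention mask for a set of draft tokens. It assumes the draft tree is given as a list of parent indices. That representation and the helper name `tree_attention_mask` are hypothetical, chosen for illustration, and are not claimed to match flash-attention or the linked transformers PR. Each draft token may attend to itself and its ancestors, so sibling branches stay invisible to each other.

```python
from typing import List


def tree_attention_mask(parents: List[int]) -> List[List[bool]]:
    """Build a [Q, Q] boolean tree attention mask over Q draft tokens.

    parents[i] is the index of draft token i's parent, or -1 if token i
    hangs directly off the already-verified prefix. Parents must come
    before their children (parents[i] < i).
    mask[i][j] is True when token i may attend to token j, i.e. when j is
    i itself or one of i's ancestors in the tree.
    """
    for i, p in enumerate(parents):
        if not -1 <= p < i:
            raise ValueError(f"parent of token {i} must be -1 or in [0, {i})")
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        node = i
        # Walk from token i up to the root, marking i and every ancestor.
        while node != -1:
            mask[i][node] = True
            node = parents[node]
    return mask
```

For example, with `parents = [-1, 0, 0, 1]`, tokens 1 and 2 are alternative continuations of token 0 and token 3 continues token 1. Row 2 is then `[True, False, True, False]`: token 2 sees token 0 but not its sibling, token 1. A plain causal mask would let token 2 see token 1, which is why the tree case cannot be expressed as causal masking.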

tridao commented 2 months ago

Sure, we'll just need someone to contribute :D

thorinf commented 2 months ago

I'm keen to try supporting the generic case, e.g. a [B, Q, K] bool mask, with conditional execution. Ideally this would cover a lot of masking cases, though I suspect optimised kernels would work better for more structured masks (like tree masks).
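The semantics of a generic boolean mask can be pinned down with a slow plain-Python reference for one batch element and one head. This is a sketch, not flash-attention's kernel or API. The name `masked_attention` and the choice to return zeros for a fully masked row (instead of NaN) are assumptions for illustration. "Conditional execution" is modelled by never computing scores for masked-out pairs and taking each row's softmax over the allowed keys only.

```python
import math
from typing import List

Matrix = List[List[float]]


def masked_attention(q: Matrix, k: Matrix, v: Matrix,
                     mask: List[List[bool]]) -> Matrix:
    """Reference attention for one batch element and one head.

    q: [Q, d], k: [K, d], v: [K, d_v], mask: [Q, K] booleans
    (True = query i may attend to key j).
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for i, qi in enumerate(q):
        # Conditional execution: only visit keys the mask allows.
        allowed = [j for j, ok in enumerate(mask[i]) if ok]
        if not allowed:
            # Fully masked row: emit zeros rather than NaN (a design choice).
            out.append([0.0] * d_v)
            continue
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in allowed]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        row = [0.0] * d_v
        for w, j in zip(weights, allowed):
            for c, vc in enumerate(v[j]):
                row[c] += (w / total) * vc
        out.append(row)
    return out
```

A fused kernel would tile this loop and keep running softmax statistics instead of materializing rows, but any such kernel should agree with this reference for every [Q, K] mask, including the structured ones discussed below.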

KexinFeng commented 1 month ago

I don't see much difference between a generic mask and a structured mask. For a tree mask, the mask argument would also have shape [B, Q, K]. The 4D attention mask mentioned above is simply [b, h, q, k], where h is the number of heads.

If you are able to implement a generic mask, then structured masks will be supported as well.

thorinf commented 1 month ago

What I mean is that for a structured mask you don't necessarily have to create a bool tensor. In the causal case, the rule can be hardcoded in the kernel (ignore j > i + k_cache), which saves a little memory. If the mask is structured, the locations you'll visit are predictable.
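The difference between a hardcoded rule and a stored mask can be sketched as follows. The sketch reads `k_cache` as the number of tokens already in the KV cache (named `cache_len` here), so query i among the new tokens sits at absolute position `cache_len + i`. Both helper names are hypothetical. The predicate needs no storage, while the materialized form costs Q × (cache_len + Q) booleans per batch element.

```python
from typing import List


def causal_allowed(i: int, j: int, cache_len: int) -> bool:
    """Causal rule with a KV cache, evaluated on the fly from indices.

    Query i (0-based among the new tokens) may attend to key j iff
    j <= i + cache_len, i.e. keys with j > i + cache_len are ignored.
    """
    return j <= i + cache_len


def materialized_causal_mask(q_len: int, cache_len: int) -> List[List[bool]]:
    """The same rule written out as a [Q, cache_len + Q] boolean tensor."""
    k_len = cache_len + q_len
    return [[causal_allowed(i, j, cache_len) for j in range(k_len)]
            for i in range(q_len)]
```

With `cache_len = 3` and two new tokens, row 0 is `[True, True, True, True, False]` and row 1 is all True. Because the allowed region is a simple staircase, a kernel can evaluate `causal_allowed` per element without ever reading a mask from memory.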

KexinFeng commented 1 month ago

I see. Yes, in the causal case the bool tensor mask argument is indeed not required. For the tree attention mask, however, this argument is unavoidable. That probably doesn't add much implementation complexity, though, since the causal mask will internally be converted to such a tensor anyway. @thorinf Looking forward to your PR!
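Why the tree case needs an explicit tensor becomes clearer once a KV cache is involved, as in the `flash_attn_with_kvcache` setting this issue describes. The sketch below assumes, for illustration only, that every draft token sees the entire cached prefix and that the tree is given as parent indices. The helper name `tree_mask_with_cache` is hypothetical and not part of flash-attention.

```python
from typing import List


def tree_mask_with_cache(parents: List[int], cache_len: int) -> List[List[bool]]:
    """[Q, cache_len + Q] mask for Q draft tokens arranged as a tree.

    The first cache_len columns (the verified prefix in the KV cache) are
    visible to every draft token. Among the draft columns, token i sees
    itself and its ancestors. parents[i] is the parent's index, or -1.
    """
    q_len = len(parents)
    mask = []
    for i in range(q_len):
        row = [True] * cache_len + [False] * q_len
        node = i
        # Mark token i and each ancestor in the draft part of the row.
        while node != -1:
            row[cache_len + node] = True
            node = parents[node]
        mask.append(row)
    return mask
```

With `parents = [-1, 0, 0]` and `cache_len = 2`, row 2 is `[True, True, True, False, True]`, whereas the causal rule j <= i + cache_len would give `[True] * 5`. The hole at column 3 (token 2's sibling) depends on the tree shape rather than on index arithmetic alone, so some per-request mask data has to reach the kernel.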

jkobject commented 1 month ago

Hello, sorry for the naive question but:

  1. Why do you need structured masking? Can't you do something similar with attention biases?
  2. Are you hoping to skip blocks that are entirely masked, or will you still compute attention over the full matrix?

It might help me understand this a bit more :)
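Both questions can be made concrete with a small sketch. The first function shows that a boolean mask and an additive bias of 0 / -inf carry the same information, because exp(-inf) = 0 removes masked keys from the softmax; a bias just stores floats instead of bools. The second finds tiles of the mask that are entirely False, which a tiled kernel could in principle skip. Whether and how flash-attention would do this for arbitrary masks is exactly what is being asked, so this is only an illustration; the function names and block sizes are assumptions.

```python
import math
from typing import List, Set, Tuple


def mask_to_bias(mask: List[List[bool]]) -> List[List[float]]:
    """Additive bias equivalent to a boolean mask: 0 where allowed, -inf where not.

    softmax(scores + bias) equals a softmax over only the allowed scores.
    A row with no allowed entries would produce NaN.
    """
    return [[0.0 if ok else -math.inf for ok in row] for row in mask]


def fully_masked_blocks(mask: List[List[bool]], block_q: int,
                        block_k: int) -> Set[Tuple[int, int]]:
    """Return (query_block, key_block) indices of tiles with no allowed entry.

    Such tiles need no Q @ K^T, softmax update, or V load at all; tiles with
    at least one allowed entry still need full work plus per-element masking.
    """
    q_len, k_len = len(mask), len(mask[0])
    skipped = set()
    for qb in range(0, q_len, block_q):
        for kb in range(0, k_len, block_k):
            tile_has_allowed = any(
                mask[i][j]
                for i in range(qb, min(qb + block_q, q_len))
                for j in range(kb, min(kb + block_k, k_len))
            )
            if not tile_has_allowed:
                skipped.add((qb // block_q, kb // block_k))
    return skipped
```

For a 4 × 4 causal mask with 2 × 2 tiles, only tile (0, 1), above the diagonal, is fully masked. For a tree mask, the set of skippable tiles depends on the tree's shape, so it would have to be derived from the mask data rather than from index arithmetic.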