Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

[v2] Attention Masking #352

Open MikeynJerry opened 1 year ago

MikeynJerry commented 1 year ago

Is there any plan to add attention masking support? PyTorch's implementation of flash attention v1 included the ability to provide an attention mask, and it would be very useful to have this feature in v2.

leizhao1234 commented 1 year ago

In fact, when you pass an attention mask to PyTorch's implementation, the flash attention path isn't used.

balachandarsv commented 1 year ago

Yes, facing the same issue. @tridao Can you please take a look at this and respond when you are available?

tridao commented 1 year ago

Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

PeterL1n commented 1 year ago

I thought masking was supported through flash_attn_varlen_func:

https://github.com/Dao-AILab/flash-attention/blob/d30f2e1cd50185c98ed88c0684b4a603f15bee37/flash_attn/flash_attn_interface.py#L454C21-L454C21
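The varlen interface does cover padding-style masks: sequences are concatenated along a single token dimension and delimited by cumulative sequence lengths, which are then passed as `cu_seqlens_q`/`cu_seqlens_k` to `flash_attn_varlen_func`. A minimal sketch of the packing convention (the helper name here is mine, not the library's):

```python
# Hypothetical helper illustrating the packing convention used by the
# varlen interface: sequences are concatenated into one long token
# dimension and delimited by cumulative sequence lengths (cu_seqlens).
def build_cu_seqlens(seq_lens):
    """Prefix sums with a leading 0: sequence i occupies
    tokens cu[i]:cu[i+1] of the packed tensor."""
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    return cu

# Three sequences of lengths 3, 5, and 2 pack into 10 tokens total.
print(build_cu_seqlens([3, 5, 2]))  # [0, 3, 8, 10]
```

This only expresses "which tokens exist per sequence", i.e. a padding mask; it cannot express an arbitrary per-position attention mask or bias.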

zhipeng93 commented 1 year ago

I have tested v1.0.7 and v2.0.4, and it turns out that neither of them supports an attention mask ---

The results of A and B are different.

samvanstroud commented 1 year ago

This paper might be relevant: https://arxiv.org/abs/2306.01160.

There are several related issues:

I believe PyTorch 2.1 will have a memory-efficient attention implementation that supports arbitrary masks: https://github.com/pytorch/pytorch/issues/96099

defei-coder commented 1 year ago

@tridao Hello, I plan to add a bias mask to FlashAttention-2. I noticed that, in order to fuse the scale and add operations in scale_apply_exp2, the scaling is delayed until after the row maximum is calculated. I plan to support the bias mask in the apply_mask_causal function. If a bias mask is added, it seems the FFMA optimization in scale_apply_exp2 would have to be dropped, but using scale and bias together could still benefit from FFMA. Do you have any suggestions?
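For context on the trick being discussed: the kernel computes exp(x) as exp2(x * log2(e)), so the multiply by the softmax scale and the subtraction of the row max collapse into a single fused multiply-add per element before one exp2. A scalar sketch (names are mine, not the kernel's):

```python
import math

LOG2E = math.log2(math.e)  # exp(x) == 2 ** (x * LOG2E)

def scale_apply_exp2(score, row_max, softmax_scale):
    # One multiply-add per element, score * a + b, then a single exp2.
    # This is the shape of the FFMA optimization discussed above.
    a = softmax_scale * LOG2E
    b = -row_max * softmax_scale * LOG2E
    return 2.0 ** (score * a + b)

def reference(score, row_max, softmax_scale):
    # Unfused reference: exp(scale * (score - row_max)).
    return math.exp(softmax_scale * (score - row_max))

print(scale_apply_exp2(1.7, 2.0, 0.125), reference(1.7, 2.0, 0.125))
```

With an additive bias the exponent becomes score * a + (bias * LOG2E + b), so a single multiply-add per element can still suffice if the bias term is folded into the add operand; that seems to be the benefit the comment above is pointing at.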

zhangyipin commented 11 months ago

flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf.
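For reference, an additive bias of -inf at masked positions is mathematically equivalent to a boolean mask: exp(-inf) = 0, so those keys receive zero attention weight. A plain NumPy sketch of the reference computation (not the Triton kernel):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, bias=None):
    # scores: (seq_q, seq_k); an additive bias broadcasts over that shape.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if bias is not None:
        scores = scores + bias
    return softmax(scores) @ v

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 2, 4))
# Mask out key 0 for every query with an additive -inf bias.
bias = np.array([[-np.inf, 0.0]])
out = attention(q, k, v, bias)
# With key 0 masked, each query attends only to key 1, so out == v[1].
assert np.allclose(out, np.broadcast_to(v[1], out.shape))
```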

wehos commented 8 months ago

flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf.

This is a good point, but the example itself does not work with PyTorch 2.0+ (i.e., Triton 2.0+) 😭

jaanli commented 7 months ago

Anyone have tips on custom masks with flash attention for training?

(I need this to train encoder-decoder models with variable-length sequences using non-causal masks.)

This came up in a recent article: https://www.yitay.net/blog/training-great-llms-entirely-from-ground-zero-in-the-wilderness

The other striking thing is how little support these codebases have for large scale encoder-decoder training or even prefixLM training. To that end, even flash attention has consistently declined to provide support for prefixLM training (i.e., custom masks) despite reasonable demand on their github issues for whatever reason.

Curious what this would take or if it is still out of scope for the flash attention library?

Really grateful that this exists!! Just posting for visibility in case others have solved this problem :)

tridao commented 7 months ago

Curious what this would take or if it is still out of scope for the flash attention library?

Not out of scope, it's just someone needs to go implement it :D

jaanli commented 7 months ago

Understood — thank you!! Will try using the varlen functions for now :)

ardagoreci commented 5 months ago

I was wondering if there has been any updates on this? AlphaFold3 uses a lot of attention pair biasing and it would be tremendously useful to computational biology if flash attention supported attention biasing!

tridao commented 5 months ago

I was wondering if there has been any updates on this? AlphaFold3 uses a lot of attention pair biasing and it would be tremendously useful to computational biology if flash attention supported attention biasing!

Right, we still need someone to implement it.

alexzhang13 commented 3 months ago

@tridao Was wondering, what needs to be done for this to be implemented (I'm assuming efficiently? otherwise it seems quite simple)

I need a similar feature (arbitrary attention masks) but I figured I might take a stab at just implementing it if it still needs to be done.

alexzhang13 commented 3 months ago

I've implemented a version of custom masking for FA2 in Triton: https://github.com/alexzhang13/flashattention2-custom-mask

It suffices for my use case, but if something comes up where it's necessary to touch the FA3 code I may re-visit this.

amyxlu commented 2 months ago

Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

Seems like the FlashAttention class does take in a key_padding_mask argument in its forward method. What would be the difference between this and the attention mask to be implemented? Cc @tridao. Thanks!

tridao commented 2 months ago

As you can see in the code, key_padding_mask just removes elements from the keys and values before passing them to the flash attention kernel. No attention mask is passed to the kernel itself.
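Concretely, the unpadding step looks roughly like this (the library ships helpers for this in flash_attn.bert_padding; the sketch below is pure Python illustrating the idea, not that module's actual signatures):

```python
# Sketch of what key_padding_mask amounts to: drop the padded tokens and
# record per-sequence boundaries, so the kernel only ever sees real tokens.
def unpad(batch_tokens, key_padding_mask):
    """batch_tokens: list of token lists (padded to equal length);
    key_padding_mask: parallel lists of booleans, True = real token."""
    packed, cu_seqlens = [], [0]
    for tokens, mask in zip(batch_tokens, key_padding_mask):
        kept = [t for t, keep in zip(tokens, mask) if keep]
        packed.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))
    return packed, cu_seqlens

tokens = [["a", "b", "<pad>"], ["c", "<pad>", "<pad>"]]
mask = [[True, True, False], [True, False, False]]
print(unpad(tokens, mask))  # (['a', 'b', 'c'], [0, 2, 3])
```

This is why key_padding_mask cannot express arbitrary masks: it can only remove whole key/value positions, not vary the mask per query.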

krejciadam commented 2 weeks ago

As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.

Is there any plan to support key_padding_mask in MHA in v2? My understanding is that this was supported in v1 (in flash_attn.flash_attention.FlashMHA), but in v2, one can only use key_padding_mask when use_flash_attn is False (in flash_attn.modules.mha.MHA). Thank you.

agshar96 commented 1 week ago

Hi everyone, I recently published a paper in the ENLSP Workshop @ NeurIPS 2024 addressing this problem; the paper can be found here: https://arxiv.org/pdf/2409.15097

I have the code, but it's in a private repository at the moment, as I am still cleaning it up. If you'd like access to the repo, just send a mail to: agnivunoff96@gmail.com

Meanwhile, I realised that the PyTorch team has already implemented a change that uses pretty much the same method I did. (I came up with my method independently for a university project; the PyTorch blog post came out around half a month after my project.)

Anyways, TL;DR: PyTorch has now enabled custom masking of flash attention; you can find it here: https://pytorch.org/blog/flexattention/ (And I am a sad man, as my method will never be used.)
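For anyone landing here: FlexAttention expresses a mask as a predicate over (batch, head, q_idx, kv_idx) that you hand to create_block_mask. The prefix-LM mask requested earlier in this thread can be written as a plain predicate; the sketch below uses ordinary Python ints so the logic is easy to check (the real mask_mod operates on tensor indices):

```python
def prefix_lm_mask(prefix_len):
    """PrefixLM: tokens inside the prefix attend bidirectionally;
    everything after the prefix is causal."""
    def mask_mod(b, h, q_idx, kv_idx):
        # Allowed if the key is in the bidirectional prefix,
        # or not in the query's future (causal).
        return kv_idx < prefix_len or kv_idx <= q_idx
    return mask_mod

mask = prefix_lm_mask(prefix_len=2)
# Query 0 can see key 1 because both are in the bidirectional prefix...
print(mask(0, 0, 0, 1))  # True
# ...but query 2 (outside the prefix) cannot see future key 3.
print(mask(0, 0, 2, 3))  # False
```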