MikeynJerry opened this issue 1 year ago
In fact, when you pass an attention mask to PyTorch's implementation, the flash attention kernel isn't used.
Yes, facing the same issue. @tridao Can you please take a look at this and respond when you are available?
Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.
I thought masking was supported through flash_attn_varlen_func.
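For the padding case, the varlen interface does express the mask implicitly: sequences are packed instead of padded. Below is a minimal NumPy sketch of the cu_seqlens packing convention behind flash_attn_varlen_func (the shapes and variable names here are illustrative, not the library's actual tensors):

```python
import numpy as np

# Sketch of the packing convention behind flash_attn_varlen_func:
# instead of padding every sequence to max_len and masking, sequences
# are concatenated along the token dimension and delimited by cumulative
# offsets (the cu_seqlens argument). The real function takes packed
# q/k/v tensors plus these offsets and a max_seqlen.
lengths = np.array([3, 5, 2])                    # actual sequence lengths
cu_seqlens = np.zeros(len(lengths) + 1, dtype=np.int32)
cu_seqlens[1:] = np.cumsum(lengths)              # [0, 3, 8, 10]

head_dim = 4
packed = np.random.randn(int(cu_seqlens[-1]), head_dim)  # (total_tokens, d)

# Sequence i is recovered as packed[cu_seqlens[i]:cu_seqlens[i+1]];
# padding never enters the kernel, so no padding mask is needed.
seq1 = packed[cu_seqlens[1]:cu_seqlens[2]]
```

This covers padding, but not arbitrary masks such as prefixLM, which is what the rest of the thread is about.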
I have tested v1.0.7 and v2.0.4; it turns out that neither of them supports an attention mask: the results of A and B are different.
This paper might be relevant: https://arxiv.org/abs/2306.01160.
There are several related issues:
I believe pytorch 2.1 will have a memory efficient attention implementation that supports arbitrary masks: https://github.com/pytorch/pytorch/issues/96099
@tridao Hello, I plan to add a bias mask in FlashAttention-2. I noticed that, in order to fuse the scale and add operations in scale_apply_exp2, the scaling is delayed until after the maximum value is calculated. I plan to support the bias mask in the apply_mask_causal function. If a bias mask is supported, it seems the current FFMA optimization in scale_apply_exp2 could be dropped, since using scale and bias together can still benefit from FFMA. Do you have any suggestions?
flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf.
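For context on why bias=-inf works: an additive bias of -inf at disallowed positions is mathematically equivalent to a boolean mask, since exp(-inf) = 0 in the softmax. A small NumPy sketch (not the Triton kernel itself, just the math):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
scores = rng.normal(size=(4, 4))

# Boolean mask: e.g. causal, disallow attending to future positions.
allowed = np.tril(np.ones((4, 4), dtype=bool))

# Equivalent additive bias: 0 where allowed, -inf where masked.
bias = np.where(allowed, 0.0, -np.inf)

probs_bias = softmax(scores + bias)                      # bias formulation
probs_bool = softmax(np.where(allowed, scores, -np.inf)) # mask formulation
assert np.allclose(probs_bias, probs_bool)
```

Any custom mask (prefixLM included) can therefore be fed to a kernel that accepts an additive bias.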
> flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf.
This is a good point, but the example itself is not working with PyTorch 2.0+ (i.e., Triton 2.0+).
Anyone have tips on custom masks with flash attention for training?
(I need this to train encoder-decoder models with variable-length sequences using non-causal masks.)
This came up in a recent article: https://www.yitay.net/blog/training-great-llms-entirely-from-ground-zero-in-the-wilderness
> The other striking thing is how little support these codebases have for large scale encoder-decoder training or even prefixLM training. To that end, even flash attention has consistently declined to provide support for prefixLM training (i.e., custom masks) despite reasonable demand on their github issues for whatever reason.
Curious what this would take or if it is still out of scope for the flash attention library?
Really grateful that this exists!! Just posting for visibility in case others have solved this problem :)
> Curious what this would take or if it is still out of scope for the flash attention library?
Not out of scope, it's just someone needs to go implement it :D
Understood — thank you!! Will try using the varlen functions for now :)
I was wondering if there have been any updates on this? AlphaFold3 uses a lot of attention pair biasing, and it would be tremendously useful to computational biology if flash attention supported attention biasing!
> I was wondering if there have been any updates on this? AlphaFold3 uses a lot of attention pair biasing and it would be tremendously useful to computational biology if flash attention supported attention biasing!
Right, we still need someone to implement it.
@tridao I was wondering, what needs to be done for this to be implemented (efficiently, I'm assuming? Otherwise it seems quite simple).
I need a similar feature (arbitrary attention masks) but I figured I might take a stab at just implementing it if it still needs to be done.
I've implemented a version of custom masking for FA2 in Triton: https://github.com/alexzhang13/flashattention2-custom-mask
It suffices for my use case, but if something comes up where it's necessary to touch the FA3 code I may re-visit this.
> Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.
Seems like the FlashAttention class does take in a key_padding_mask argument in its forward method. What would be the difference between this and the attention mask to be implemented? Cc @tridao. Thanks!
As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.
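A small NumPy reference check of why removing keys is equivalent for padding: masking padded keys with -inf gives the same output as simply dropping them before the kernel call (toy single-head attention, illustrative only):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attn(q, k, v, mask=None):
    # Toy scaled dot-product attention; mask is True where attending is allowed.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(6, 8))
v = rng.normal(size=(6, 8))

# Mask out the last two keys (padding) ...
key_mask = np.array([True] * 4 + [False] * 2)
out_masked = attn(q, k, v, mask=np.broadcast_to(key_mask, (4, 6)))
# ... which equals removing them before the kernel sees them:
out_trunc = attn(q, k[:4], v[:4])
assert np.allclose(out_masked, out_trunc)
```

This equivalence only holds for key-padding masks; an arbitrary per-query mask (e.g. prefixLM) cannot be expressed by dropping keys, which is why the kernel-level mask is still requested in this thread.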
> As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.
Is there any plan to support key_padding_mask in MHA in v2? My understanding is that this was supported in v1 (in flash_attn.flash_attention.FlashMHA), but in v2 one can only use key_padding_mask when use_flash_attn is False (in flash_attn.modules.mha.MHA). Thank you.
Hi everyone, I recently published a paper at the ENLSP Workshop @ NeurIPS 2024 that addresses this problem; the paper can be found here: https://arxiv.org/pdf/2409.15097
I have the code, but it's in a private repository currently, as I am still cleaning it up. If someone wants access to the repo, just send a mail to: agnivunoff96@gmail.com
Meanwhile, I realised that the PyTorch team has already implemented a change that uses pretty much the same method I used. (I came up with my method independently for a university project; the PyTorch blog came out around half a month after my university project.)
Anyway, TL;DR: PyTorch has now enabled custom masking of flash attention; you can find it here: https://pytorch.org/blog/flexattention/ (And I am a sad man, as my method will never be used.)
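For reference, FlexAttention expresses masks as a mask_mod(b, h, q_idx, kv_idx) predicate that is compiled into the kernel. Below is a NumPy sketch of a prefix-LM predicate written in that style; it does not call FlexAttention itself, and prefix_len and the toy sizes are illustrative choices:

```python
import numpy as np

# Sketch of a FlexAttention-style mask_mod predicate for prefix-LM:
# tokens inside the prefix are visible to everyone (bidirectional),
# tokens after the prefix follow the usual causal rule.
prefix_len = 3  # illustrative, not a FlexAttention parameter name

def prefix_lm_mask(b, h, q_idx, kv_idx):
    # True where query q_idx may attend to key kv_idx.
    return (kv_idx < prefix_len) | (kv_idx <= q_idx)

# Materialize the mask on a toy 6x6 grid to see its shape.
seq_len = 6
q_idx, kv_idx = np.meshgrid(np.arange(seq_len), np.arange(seq_len),
                            indexing="ij")
mask = prefix_lm_mask(0, 0, q_idx, kv_idx)

assert mask[:, :prefix_len].all()  # every query sees the whole prefix
assert not mask[2, 5]              # but not future non-prefix tokens
assert mask[5, 4]                  # causal part is lower-triangular
```

In FlexAttention the same predicate would be passed (via create_block_mask) instead of being materialized, which is what makes it usable at scale.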
Is there any plan to add attention masking support? PyTorch's version of flash attention v1 included the ability to provide an attention mask in its implementation, and it would be very useful to have this feature in v2.