Hi @Junyoungpark
Because the returned attention weights might not be correct (see e.g.: https://github.com/Dao-AILab/flash-attention/blob/61a777247900f6c2a37376f3ffd7134385fdc95c/flash_attn/flash_attn_interface.py#L668), we unfortunately had to force-disable `output_attentions` to `False` for all FA-2 models. We can consider enabling it once we have guarantees on the FA-2 side that the returned attention weights are correct.
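In the meantime, if you need the attention maps, the usual workaround is to fall back to the eager attention implementation. A minimal sketch (assuming a transformers version that supports the `attn_implementation` argument; the checkpoint name is only an example):

```python
# Minimal workaround sketch, assuming a transformers version that supports the
# attn_implementation argument; the checkpoint name below is only an example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",            # requires `accelerate`
    attn_implementation="eager",  # "flash_attention_2" cannot return attentions
)

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
outputs = model(**inputs, output_attentions=True)
# One attention tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)
```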
Hi @younesbelkada.
Thanks for the detailed reply. I understand the current situation. Thanks 👍
Thank you @Junyoungpark !
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @younesbelkada
I am facing the same issue: I need to return the attention maps while using Flash Attention 2. It would be appreciated if you could clarify these points:
Thanks in advance!
Hi all,
I've discovered that Mistral models with flash attention cannot return `attention_weights` due to a reference error. I expect we could address this by passing `return_attn_probs=True` to the flash attention API, but there is still some uncertainty: `flash_attn_func` can return the attention weights, although the official API doc notes that the returned weights may not be entirely correct. You can find the relevant code snippet in the Mistral modeling file here: https://github.com/huggingface/transformers/blob/1c31b7aa3bb4e7ef24c77596d2a76f45a770159f/src/transformers/models/mistral/modeling_mistral.py#L468
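For reference, a minimal sketch of what `return_attn_probs=True` looks like at the flash-attn level (assuming flash-attn 2.x on a CUDA device; per the docstring the returned probabilities are intended for testing only and may not have the correct scaling):

```python
# Minimal sketch, assuming flash-attn 2.x and a CUDA device. Per the flash-attn
# docstring, the returned probabilities are for testing only and may not have
# the correct scaling, which is why transformers does not expose them.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# With return_attn_probs=True, flash_attn_func also returns the softmax
# log-sum-exp and the (testing-only) attention probability tensor.
out, softmax_lse, attn_probs = flash_attn_func(
    q, k, v, dropout_p=0.0, causal=True, return_attn_probs=True
)
print(out.shape, softmax_lse.shape)
if attn_probs is not None:  # may only be materialized in some configurations
    print(attn_probs.shape)
```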