huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mistral with flash attention cannot return `attention_weights` #28903

Closed Junyoungpark closed 8 months ago

Junyoungpark commented 9 months ago

Hi all,

I’ve discovered that Mistral models with flash attention cannot return `attention_weights` due to a reference error. I suspect this could be addressed by passing `return_attn_probs=True` to the flash attention API, though there is still some uncertainty: `flash_attn_func` can return the attention weights, but according to the official API doc the returned weights may not be entirely correct.

You can find the relevant code snippet in the Mistral modeling file here: https://github.com/huggingface/transformers/blob/1c31b7aa3bb4e7ef24c77596d2a76f45a770159f/src/transformers/models/mistral/modeling_mistral.py#L468
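
For reference, a minimal repro sketch (the checkpoint name and prompt are just examples; this assumes a CUDA GPU with `flash-attn` installed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,               # FA-2 requires fp16/bf16
    attn_implementation="flash_attention_2",
    device_map="cuda",
)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)

# Requesting attention weights while FA-2 is active is where things break:
# depending on the transformers version this errors out or the attentions
# simply are not populated.
outputs = model(**inputs, output_attentions=True)
print(outputs.attentions)
```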

younesbelkada commented 9 months ago

Hi @Junyoungpark Because the returned attention weights might not be correct (e.g.: https://github.com/Dao-AILab/flash-attention/blob/61a777247900f6c2a37376f3ffd7134385fdc95c/flash_attn/flash_attn_interface.py#L668), we unfortunately had to force `output_attentions` to `False` for all FA-2 models. We can consider enabling it once we have guarantees on the FA-2 side that the returned output attentions are correct.
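
In the meantime, if you need the attention maps, a common workaround is to load the model with eager attention instead of FA-2; it is slower, but `output_attentions=True` is supported there. A rough sketch (example checkpoint name):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
# "eager" attention materializes the full attention matrix, so the per-layer
# attention weights can be returned.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="eager",
    device_map="cuda",
)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
outputs = model(**inputs, output_attentions=True)

# One tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)
```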

Junyoungpark commented 9 months ago

Hi @younesbelkada.

Thanks for the detailed reply. I understand the current situation. Thanks 👍

younesbelkada commented 9 months ago

Thank you @Junyoungpark !

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

eslambakr commented 1 month ago

Hi @younesbelkada

I am facing the same issue: I need to return the attention maps while using Flash Attention 2, so it would be appreciated if you could clarify these points:

Thanks in advance!
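
For reference, this is the raw flash-attn call mentioned above with `return_attn_probs=True` (toy shapes, illustrative only; the flash-attn docstring notes the returned probabilities are for testing and may not have the right scaling):

```python
import torch
from flash_attn import flash_attn_func

# Toy tensors: (batch, seq_len, num_heads, head_dim); flash-attn needs
# fp16/bf16 tensors on a CUDA device.
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

# With return_attn_probs=True the call returns (out, softmax_lse, attn_probs).
# In some flash-attn versions attn_probs is only populated when dropout_p > 0,
# and the values are documented as testing-only (scaling may be off).
out, softmax_lse, attn_probs = flash_attn_func(
    q, k, v, dropout_p=0.1, causal=True, return_attn_probs=True
)

print(out.shape, softmax_lse.shape)
print(None if attn_probs is None else attn_probs.shape)
```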