huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers

Flash Attention on Finetuned Models #29944

Closed: RohitMidha23 closed this issue 6 months ago

RohitMidha23 commented 6 months ago

Who can help?

@sanchit-gandhi

Reproduction

  1. Fine-tuned a model as described in the Hugging Face blog.
  2. Used InsanelyFastWhisper for predictions with the translate task (a hypothetical reproduction sketch follows the traceback below).
File "/root/.pyenv/versions/3.11.6/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 432, in forward
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
ValueError: WhisperFlashAttention2 attention does not support output_attentions
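
For reference, a hypothetical sketch of the kind of call that triggers this error (the model ID and audio path are placeholders, and the exact InsanelyFastWhisper invocation may differ); the problematic combination is Flash Attention 2 together with word-level timestamps:

```python
# Hypothetical reproduction sketch - model ID and audio path are placeholders.
import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/whisper-finetuned",  # fine-tuned checkpoint (placeholder)
    torch_dtype=torch.float16,
    device="cuda:0",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

# Requesting word-level timestamps implicitly sets output_attentions=True,
# which WhisperFlashAttention2 does not support -> raises the ValueError above.
out = asr(
    "audio.wav",
    generate_kwargs={"task": "translate"},
    return_timestamps="word",
)
```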

Expected behavior

FlashAttention2 should work with finetuned models.

sanchit-gandhi commented 6 months ago

Hey @RohitMidha23 - could you share a reproducible code snippet? Without one, it's hard to say what the error is here, but it's likely fixed by ensuring output_attentions=False, since outputting the attentions is not compatible with Flash Attention 2.

Note for this, you will also need to disable word-level timestamps, since these implicitly set output_attentions=True in the generation code: https://github.com/huggingface/transformers/blob/536ea2aca234fb48c5c69769431d643b0d93b233/src/transformers/models/whisper/generation_whisper.py#L1013-L1016
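
A minimal sketch of that suggestion, assuming a standard transformers pipeline setup (the model ID is a placeholder): keep Flash Attention 2 but request only segment-level timestamps, so output_attentions stays False:

```python
# Sketch of the suggested fix, assuming a pipeline-based setup (model ID is a placeholder):
# keep Flash Attention 2, but don't request word-level timestamps.
import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/whisper-finetuned",  # placeholder
    torch_dtype=torch.float16,
    device="cuda:0",
    model_kwargs={"attn_implementation": "flash_attention_2"},
)

out = asr(
    "audio.wav",
    generate_kwargs={"task": "translate"},
    return_timestamps=True,  # segment-level timestamps work with Flash Attention 2
)
```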

RohitMidha23 commented 6 months ago

You're right, it was due to word-level timestamps. Thanks @sanchit-gandhi!

ArmykOliva commented 1 month ago

I need word-level timestamps. How can I fix this? I know it's possible because the insanely-fast-whisper Replit hosting has them and it works there.

marziye-A commented 2 weeks ago

I have the exact same problem and I need the word-level timestamps. How should I fix it? Where should I set output_attentions=False?
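
Not an official answer from the maintainers, but one possible workaround sketch: you don't set output_attentions yourself; it is enabled implicitly by the word-timestamp path, and Flash Attention 2 cannot return attention weights. If word-level timestamps are required, a straightforward option is to load the model without Flash Attention 2 (the model ID below is a placeholder):

```python
# Hedged workaround sketch (model ID is a placeholder): if word-level timestamps
# are required, load the model without Flash Attention 2, since the timestamp
# alignment needs cross-attention weights that FlashAttention2 cannot return.
import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/whisper-finetuned",  # placeholder
    torch_dtype=torch.float16,
    device="cuda:0",
    model_kwargs={"attn_implementation": "eager"},  # or "sdpa"; not "flash_attention_2"
)

out = asr(
    "audio.wav",
    generate_kwargs={"task": "translate"},
    return_timestamps="word",  # word-level timestamps now work, at some speed cost
)
```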