huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

BetterTransformer optimization / flash_attn{_2} #93

Closed thomaschhh closed 3 months ago

thomaschhh commented 3 months ago

I have been trying to replicate the training steps of distil-whisper as described in training/

However, when running the pseudo-labeling step I run into the following error:

ValueError: Transformers now supports natively BetterTransformer optimizations 
(torch.nn.functional.scaled_dot_product_attention) for the model type whisper. As such, there is no need to use 
`model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. Please upgrade to 
transformers>=4.36 and torch>=2.1.1 to use it. Details:

After that, I decided to set --attn_type "flash_attn_2". However, this throws the following error:

/flash_attn/", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype

Is this a known error, or have I been doing something wrong?
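For context, the first error message says that recent Transformers releases select the attention backend natively, so the choice is made when the model is loaded rather than through Optimum. The second error is raised inside the Flash Attention kernel, which as far as I know only accepts float16 or bfloat16 tensors; whether a dtype mismatch of that kind caused the error above is not established in this thread. The following is a minimal sketch under those assumptions. It assumes transformers>=4.36 and torch>=2.1.1, as the message requests; the helper names `check_attention_config` and `load_whisper` are hypothetical and do not come from the training scripts.

```python
# Hypothetical helpers, not part of the distil-whisper training scripts.
HALF_PRECISION = {"float16", "bfloat16"}


def check_attention_config(implementation: str, dtype_name: str) -> None:
    """Fail early if Flash Attention 2 is requested with a full-precision dtype."""
    if implementation == "flash_attention_2" and dtype_name not in HALF_PRECISION:
        raise ValueError(
            f"flash_attention_2 needs one of {sorted(HALF_PRECISION)}, got {dtype_name!r}"
        )


def load_whisper(model_id: str, implementation: str = "sdpa", dtype_name: str = "float16"):
    """Load a Whisper checkpoint with an explicit attention backend and dtype."""
    check_attention_config(implementation, dtype_name)

    # Imported lazily so the check above can be used without torch installed.
    import torch
    from transformers import WhisperForConditionalGeneration

    return WhisperForConditionalGeneration.from_pretrained(
        model_id,
        attn_implementation=implementation,  # "eager", "sdpa" or "flash_attention_2"
        torch_dtype=getattr(torch, dtype_name),
    )
```

Calling `load_whisper(model_id, implementation="sdpa")` with a Whisper checkpoint of your choice would use the native scaled dot product attention path without any call to Optimum.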

guynich commented 3 months ago

This might help:

thomaschhh commented 3 months ago

Looks good for --attn_type "flash_attn" but not for --attn_type "flash_attn_2". In that case I still get the above-mentioned error.

guynich commented 3 months ago

I see: I've only used "flash_attn" and needed this PR.

sanchit-gandhi commented 3 months ago

Fixed in #101! You can now set --attn_implementation to one of {"eager", "sdpa", "flash_attn_2"}:

The README has been updated to reflect this change:

sanchit-gandhi commented 3 months ago

Closing as resolved! Feel free to re-open if you continue to encounter issues.

thomaschhh commented 2 months ago

I just looked into this again, and there seems to be a mismatch between the help string and the value expected in L142: the code expects flash_attention_2, not flash_attn_2.
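One way to keep a help string and the accepted values from drifting apart is to derive both from a single tuple and let the argument parser enforce it. The sketch below uses plain `argparse` for illustration, which may differ from how the training scripts parse their arguments; the names `ATTN_IMPLEMENTATIONS` and `build_parser` are hypothetical.

```python
import argparse

# Single source of truth for the accepted values (hypothetical name).
ATTN_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")


def build_parser() -> argparse.ArgumentParser:
    """Build a parser whose help text and validation share one list of values."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation",
        choices=ATTN_IMPLEMENTATIONS,  # unknown values are rejected at parse time
        default="sdpa",
        help="Attention backend, one of: " + ", ".join(ATTN_IMPLEMENTATIONS),
    )
    return parser
```

With this layout, passing `--attn_implementation flash_attn_2` fails immediately with a usage error that lists the valid choices, instead of reaching the model loading code.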