huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

BetterTransformer optimization / flash_attn{_2} #93

Closed: thomaschhh closed this issue 3 months ago

thomaschhh commented 3 months ago

I have been trying to replicate the distil-whisper training steps described in training/README.md.

However, when running the pseudo-labeling step I run into the following error:

ValueError: Transformers now supports natively BetterTransformer optimizations 
(torch.nn.functional.scaled_dot_product_attention) for the model type whisper. As such, there is no need to use 
`model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. Please upgrade to 
transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention.
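The error names the versions that bring native SDPA support for Whisper: transformers>=4.36 and torch>=2.1.1. Below is a minimal sketch of gating the attention backend on those minimums. The helper names are hypothetical, and the parser only reads the leading numeric part of a version string, so suffixes such as `+cu121` or `.dev0` are ignored.

```python
import re

# Minimum versions quoted in the ValueError above.
MIN_TRANSFORMERS = (4, 36)
MIN_TORCH = (2, 1, 1)


def parse_version(version: str) -> tuple:
    """Keep only the leading dotted run of integers, e.g. '2.1.1+cu121' -> (2, 1, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Compare component-wise, zero-padding so (4, 36) equals (4, 36, 0)."""
    parsed = parse_version(version)
    width = max(len(parsed), len(minimum))
    padded = parsed + (0,) * (width - len(parsed))
    floor = minimum + (0,) * (width - len(minimum))
    return padded >= floor


def native_sdpa_available(transformers_version: str, torch_version: str) -> bool:
    """True when both libraries meet the minimums from the error message."""
    return meets_minimum(transformers_version, MIN_TRANSFORMERS) and meets_minimum(
        torch_version, MIN_TORCH
    )
```

In practice the inputs would be `transformers.__version__` and `torch.__version__`. For full version semantics (pre-releases, local tags), `packaging.version.Version` is the more robust comparison.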

After that, I tried setting --attn_type "flash_attn_2", which throws the following error:

/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype

Is this a known error, or have I been doing something wrong?
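The RuntimeError above is raised inside the flash-attn kernel, which requires query, key and value to share one dtype and, per the flash-attn README, supports only fp16 and bf16. Below is a minimal pure-Python sketch of those two preconditions. Dtypes are passed as plain strings (an assumption so the check runs without PyTorch), the function name is hypothetical, and the sketch does not diagnose why the pseudo-labelling script ended up with mismatched dtypes.

```python
from typing import Optional

# Half-precision dtypes the flash-attn kernels accept.
FLASH_ATTN_DTYPES = {"float16", "bfloat16"}


def flash_attn_dtype_problem(
    query_dtype: str, key_dtype: str, value_dtype: str
) -> Optional[str]:
    """Describe why the kernel would reject these inputs, or return None if they look fine."""
    # All three tensors must share a single dtype.
    if not (query_dtype == key_dtype == value_dtype):
        return f"query/key/value dtypes differ: {query_dtype}, {key_dtype}, {value_dtype}"
    # That shared dtype must be fp16 or bf16.
    if query_dtype not in FLASH_ATTN_DTYPES:
        return f"expected float16 or bfloat16, got {query_dtype}"
    return None
```

With PyTorch tensors, `str(t.dtype)` returns strings such as "torch.float16", so the `torch.` prefix would need stripping before calling this check.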

guynich commented 3 months ago

This might help: https://github.com/huggingface/distil-whisper/pull/76

thomaschhh commented 3 months ago

That fixes --attn_type "flash_attn", but not --attn_type "flash_attn_2", which still raises the error above.

guynich commented 3 months ago

I see. I've only used "flash_attn", and that needed this PR.

sanchit-gandhi commented 3 months ago

Fixed in #101! You can now set --attn_implementation to one of {"eager", "sdpa", "flash_attn_2"}: https://github.com/huggingface/distil-whisper/blob/b948d0269c6f071708c55de4a1e4030cd7726f14/training/run_pseudo_labelling.py#L136-L139

The README has been updated to reflect this change: https://github.com/huggingface/distil-whisper/tree/main/training#1-pseudo-labelling

sanchit-gandhi commented 3 months ago

Closing as resolved! Feel free to re-open if you continue to encounter issues.

thomaschhh commented 2 months ago

I just looked into this again, and there seems to be a mismatch between the help string at https://github.com/huggingface/distil-whisper/blob/b6400a3ff1b95e1125f9c2aecba25b97712f9465/training/run_distillation.py#L136 and the value expected on L142, which is flash_attention_2, not flash_attn_2.
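The mismatch comes down to two spellings. transformers' `attn_implementation` argument takes "eager", "sdpa" or "flash_attention_2", so a help string advertising flash_attn_2 points users at a spelling the code does not expect. Below is a minimal sketch that validates the flag and maps the short spelling onto the long one. It uses plain `argparse`, not necessarily how the distil-whisper scripts parse arguments, and the alias table and helper names are hypothetical.

```python
import argparse

# Values accepted by transformers' `attn_implementation` argument.
VALID_IMPLEMENTATIONS = ("eager", "sdpa", "flash_attention_2")
# Hypothetical alias table so the spelling from the help string still works.
ALIASES = {"flash_attn_2": "flash_attention_2"}


def normalize_attn_implementation(value: str) -> str:
    """Map aliases to canonical names and reject anything else."""
    value = ALIASES.get(value, value)
    if value not in VALID_IMPLEMENTATIONS:
        raise argparse.ArgumentTypeError(
            f"invalid attn_implementation {value!r}; "
            f"choose from {', '.join(VALID_IMPLEMENTATIONS)}"
        )
    return value


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attn_implementation",
        type=normalize_attn_implementation,
        default=None,  # None leaves the choice to transformers' default
        help="Attention backend: eager, sdpa or flash_attention_2 "
        "(flash_attn_2 is accepted as an alias).",
    )
    return parser
```

The parsed value could then be passed through unchanged, e.g. `WhisperForConditionalGeneration.from_pretrained(model_id, attn_implementation=args.attn_implementation)`. Accepting the alias is a design preference. Rejecting flash_attn_2 outright and fixing the help string would be equally valid.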