shashikg / WhisperS2T

An Optimized Speech-to-Text Pipeline for the Whisper Model Supporting Multiple Inference Engines
MIT License
315 stars 32 forks

Deprecated flag for flash attention 2 with HuggingFace backend #39

Open BBC-Esq opened 9 months ago

BBC-Esq commented 9 months ago

Hello, just FYI in case you didn't know: apparently Hugging Face changed the flag/parameter for specifying flash attention 2. Here's the message I got:

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
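Editor's note: the warning points at two separate fixes, which can be sketched as follows. This is a minimal, hypothetical helper (not part of WhisperS2T); the version cutoff of 4.36 for `attn_implementation` is an assumption about the transformers release history.

```python
def fa2_kwargs(transformers_version: str) -> dict:
    """Return model-loading kwargs that request Flash Attention 2.

    Assumes `attn_implementation` became available around transformers
    v4.36 (hypothetical cutoff); older releases only understand the
    deprecated `use_flash_attention_2` flag from the warning above.
    """
    major, minor = (int(x) for x in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 36):
        # current API per the deprecation message
        return {"attn_implementation": "flash_attention_2"}
    # fall back to the deprecated flag on older releases
    return {"use_flash_attention_2": True}
```

The second half of the warning is orthogonal: after `from_pretrained(...)`, the model should be moved to the GPU with `model.to('cuda')` before Flash Attention 2 kernels can run.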

And here's the script I am testing:

import whisper_s2t

model_kwargs = {
    'compute_type': 'float16',
    'asr_options': {
        "beam_size": 5,
        "without_timestamps": True,
        "return_scores": False,
        "return_no_speech_prob": False,
        "use_flash_attention": True,
        "use_better_transformer": False,
    },
    'model_identifier': "small",
    'backend': 'HuggingFace',
}

model = whisper_s2t.load_model(**model_kwargs)

files = ['test_audio_flac.flac']
lang_codes = ['en']
tasks = ['transcribe']
initial_prompts = [None]

out = model.transcribe_with_vad(files,
                                lang_codes=lang_codes,
                                tasks=tasks,
                                initial_prompts=initial_prompts,
                                batch_size=20)

transcription = " ".join([_['text'] for _ in out[0]]).strip()

with open('transcription.txt', 'w') as f:
    f.write(transcription)

BTW, I tried using the newer attn_implementation="flash_attention_2" with Bark and COULD NOT get it to work... yet with your program, which uses the old use_flash_attention_2=True, it works. I don't know if it was my script or the different flags, but just be aware in case.

shashikg commented 8 months ago

Thanks, will update this in the next release.