huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

How to use distil-whisper-large-v3-de-kd model from HF? #95

Open Arche151 opened 3 months ago

Arche151 commented 3 months ago

Officially, multi-language support is still not implemented in distil-whisper.

But I noticed that the esteemed @sanchit-gandhi uploaded a German model for distil-whisper to Hugging Face, called 'distil-whisper-large-v3-de-kd'.

How can I use this specific model for transcribing something?

sanchit-gandhi commented 3 months ago

Hey @Arche151 - the "official" Distil-Whisper checkpoints were trained only on English speech recognition data, thus they can only be used for English. However, the training code provided in this repository generalises to all languages: https://github.com/huggingface/distil-whisper/tree/main/training

The checkpoint you've mentioned is trained using exactly this approach on German speech recognition data, giving a model compatible with German audio. You can use it in exactly the same way as the original Distil-Whisper checkpoints:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "sanchit-gandhi/distil-whisper-large-v3-de-kd"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("facebook/multilingual_librispeech", "german", split="validation", streaming=True)
sample = next(iter(dataset))["audio"]

result = pipe(sample, generate_kwargs={"language": "german", "task": "transcribe"})
print(result["text"])
```

Printed output:

```
 dann, und als er das sagte, übertrieb er sehr arg, wie alle, die in Italien geliebt haben.
```

sanchit-gandhi commented 3 months ago

You can even use the recipe for this checkpoint to train a model in a different language: https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-de-kd#training-procedure

Simply set `--train_dataset_config_name` and `--eval_dataset_config_name` to the Common Voice language config of your choice: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
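As a rough illustration of that swap, the sketch below takes an existing argument list for the training script and points both dataset-config flags at another Common Voice language. Only the two flag names come from the comment above; the helper `swap_dataset_config`, the placeholder `--some_other_flag`, and the example values are hypothetical, and the real argument list should be taken from the model card's training procedure.

```python
# The two flags named in the comment above; everything else here is illustrative.
DATASET_CONFIG_FLAGS = ("--train_dataset_config_name", "--eval_dataset_config_name")


def swap_dataset_config(args: list[str], language_config: str) -> list[str]:
    """Return a copy of `args` with both dataset-config flags set to `language_config`.

    Hypothetical helper: handles both the "--flag value" and "--flag=value" forms.
    """
    out: list[str] = []
    skip_value = False
    for arg in args:
        if skip_value:  # drop the old value that followed a "--flag value" pair
            skip_value = False
            continue
        name = arg.split("=", 1)[0]
        if name in DATASET_CONFIG_FLAGS:
            if "=" in arg:
                out.append(f"{name}={language_config}")
            else:
                out.extend([name, language_config])
                skip_value = True
        else:
            out.append(arg)
    return out


if __name__ == "__main__":
    # Hypothetical starting arguments; "--some_other_flag" stands in for the rest of the recipe.
    base_args = [
        "--train_dataset_config_name", "de",
        "--eval_dataset_config_name=de",
        "--some_other_flag", "x",
    ]
    print(swap_dataset_config(base_args, "fr"))
    # → ['--train_dataset_config_name', 'fr', '--eval_dataset_config_name=fr', '--some_other_flag', 'x']
```

Config names in the Common Voice dataset are language codes (e.g. `de`, `fr`), so the same helper works for any language listed on the dataset page.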

Arche151 commented 3 months ago

@sanchit-gandhi Thanks a lot for the quick and insightful response!!

Last question: do you also know whether I can use the German model with faster-whisper directly, or whether it would need to be converted with CTranslate2 first?

Okay, last last question: do you know the speed difference between transcription with distil-whisper and with faster-distil-whisper? I couldn't find a comparison.

sanchit-gandhi commented 3 months ago

Hey @Arche151, no problem!

  1. You would indeed need to convert the weights from HF Transformers format to faster-whisper format (CTranslate2). You can use this script for the conversion.
  2. Distil-Whisper is an architectural change that leads to a faster model (the model itself is inherently faster). Faster-Whisper is an implementation change (the model is the same, but the code is more efficient). This means the speed gains from Distil-Whisper should carry over to Faster-Whisper (you now have a more efficient model and more efficient code), but the combined speed-up will not necessarily equal the product of the two individual speed-ups.

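For reference, a conversion along these lines can also be done from Python with CTranslate2's `TransformersConverter` (the `ct2-transformers-converter` command-line tool wraps the same converter). This is a sketch, not necessarily the script linked above: the output directory name is an assumption, and copying `tokenizer.json` and `preprocessor_config.json` assumes both files exist in the Hub repo.

```python
def convert_to_ct2(
    model_id: str = "sanchit-gandhi/distil-whisper-large-v3-de-kd",
    output_dir: str = "distil-whisper-large-v3-de-kd-ct2",  # assumed output path
    quantization: str = "float16",
) -> str:
    """Convert a Transformers Whisper checkpoint to CTranslate2 format for faster-whisper."""
    # Imported lazily so this module can be imported without ctranslate2 installed.
    from ctranslate2.converters import TransformersConverter

    converter = TransformersConverter(
        model_id,
        # Copy the tokenizer and feature-extractor configs next to the converted
        # weights (assumes both files exist in the Hub repo).
        copy_files=["tokenizer.json", "preprocessor_config.json"],
    )
    # Writes the converted model into output_dir with the requested weight quantization.
    converter.convert(output_dir, quantization=quantization)
    return output_dir


if __name__ == "__main__":
    print(convert_to_ct2())
```

The resulting directory can then be passed to faster-whisper's `WhisperModel` as a local path.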
Arche151 commented 3 months ago

@sanchit-gandhi Ahh okay, that background information is super useful. Thanks for taking the time to explain!

Thanks for the script! Should've read the repo better 😅

Arche151 commented 3 months ago

@sanchit-gandhi So, I converted the model with float16 quantization, and the transcription quality compared to the original large-v3 is really bad :(

A lot of words are transcribed incorrectly, some words are not transcribed at all, and some are transcribed twice, so there are duplicates.

In my test script I wrote this: `model = WhisperModel("/path/distil-whisper-large-v3-de-kd-ct2", device="cpu", compute_type="int8")`

sanchit-gandhi commented 3 months ago

Hey @Arche151 - could you provide a reproducible code snippet for the behaviour you're seeing, both for the original large-v3 model and for the Distil-Whisper one? This should help pinpoint where the divergence occurs.

sanchit-gandhi commented 3 months ago

I'm guessing here that you're using a long audio file (> 30 seconds)? One of the limitations of the current Distil-Whisper training code is that it shifts the model's distribution towards shorter audio lengths. This means it often breaks when used with OpenAI's sequential long-form algorithm. I'm going to push updated training code in the coming days that addresses this.

Once this is ready, we can re-train the model, and it should then be compatible with other Whisper libraries, like Faster-Whisper.
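To rule the long-form case in or out, it is enough to check the recording's duration before transcribing. A minimal sketch with Python's standard `wave` module, assuming the recording is an uncompressed WAV file such as the `/tmp/audio_recording.wav` written by `arecord` in the script below; the helper names are hypothetical, and the 30-second threshold is Whisper's input window.

```python
import wave

WHISPER_WINDOW_S = 30.0  # Whisper processes audio in 30-second windows


def wav_duration_seconds(path: str) -> float:
    """Return the duration of an uncompressed WAV file in seconds."""
    with wave.open(path, "rb") as wav:
        return wav.getnframes() / wav.getframerate()


def is_long_form(path: str, window_s: float = WHISPER_WINDOW_S) -> bool:
    """True if the recording is longer than one Whisper input window."""
    return wav_duration_seconds(path) > window_s


if __name__ == "__main__":
    path = "/tmp/audio_recording.wav"
    print(f"{wav_duration_seconds(path):.1f} s, long-form: {is_long_form(path)}")
```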

Arche151 commented 3 months ago

@sanchit-gandhi Thanks for getting back to me!

I'm actually using audio clips that are shorter than 30 seconds.

Here's the complete Python script I'm using with distil-whisper-large-v3-de-kd:

```python
import os
import subprocess
from faster_whisper import WhisperModel

audio_file = "/tmp/audio_recording.wav"
recording_state_file = "/tmp/recording_state"

def start_recording():
    subprocess.Popen(["arecord", "-f", "cd", audio_file])
    open(recording_state_file, 'w').close()

def stop_recording():
    subprocess.call(["pkill", "arecord"])
    if os.path.exists(recording_state_file):
        os.remove(recording_state_file)
    transcribe_audio()
    os.remove(audio_file)

def is_recording():
    return os.path.exists(recording_state_file)

def transcribe_audio():
    model = WhisperModel("path/to/distil-whisper-large-v3-de-kd-ct2", device="cpu")
    segments, info = model.transcribe(audio_file)
    transcription = " ".join([segment.text for segment in segments]).strip()
    subprocess.Popen(["xclip", "-selection", "c"], stdin=subprocess.PIPE).communicate(input=transcription.encode())
    # Notify the user that transcription is complete and copied to clipboard
    subprocess.call(["notify-send", "Transcription Complete", "The transcription has been copied to the clipboard."])

def main():
    if is_recording():
        stop_recording()
    else:
        start_recording()

if __name__ == "__main__":
    main()
```

And here's the same script with large-v3 instead, where the transcription works fine:

```python
import os
import subprocess
from faster_whisper import WhisperModel

audio_file = "/tmp/audio_recording.wav"
recording_state_file = "/tmp/recording_state"

def start_recording():
    subprocess.Popen(["arecord", "-f", "cd", audio_file])
    open(recording_state_file, 'w').close()

def stop_recording():
    subprocess.call(["pkill", "arecord"])
    if os.path.exists(recording_state_file):
        os.remove(recording_state_file)
    transcribe_audio()
    os.remove(audio_file)

def is_recording():
    return os.path.exists(recording_state_file)

def transcribe_audio():
    model = WhisperModel("large-v3", device="cpu", compute_type="int8")
    segments, info = model.transcribe(audio_file)
    transcription = " ".join([segment.text for segment in segments]).strip()
    subprocess.Popen(["xclip", "-selection", "c"], stdin=subprocess.PIPE).communicate(input=transcription.encode())
    # Notify the user that transcription is complete and copied to clipboard
    subprocess.call(["notify-send", "Transcription Complete", "The transcription has been copied to the clipboard."])

def main():
    if is_recording():
        stop_recording()
    else:
        start_recording()

if __name__ == "__main__":
    main()
```

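For a side-by-side reproduction like the one requested earlier in the thread, and to get rough numbers for the earlier speed question, the sketch below loads each model once and transcribes the same file with both, timing each run. Assumptions: the converted model lives at the placeholder path used in the script above, both models run with the same `compute_type` so the comparison is like-for-like, and the language is pinned to German through faster-whisper's `language` argument instead of being auto-detected. Pinning the language only removes one possible source of divergence; it is not a confirmed fix for the quality gap. The helper names are hypothetical.

```python
import time

# Placeholder paths/names: adjust to your setup.
MODELS = {
    "distil-de": "path/to/distil-whisper-large-v3-de-kd-ct2",
    "large-v3": "large-v3",
}


def transcribe_timed(model, audio_path: str, language: str = "de") -> tuple[str, float]:
    """Transcribe with an already-loaded faster-whisper model; return (text, seconds)."""
    start = time.perf_counter()
    segments, _info = model.transcribe(audio_path, language=language)
    # `segments` is a lazy generator: decoding happens while it is iterated,
    # so the join must stay inside the timed region.
    text = " ".join(segment.text.strip() for segment in segments)
    return text, time.perf_counter() - start


def compare(audio_path: str, compute_type: str = "int8") -> None:
    from faster_whisper import WhisperModel  # imported lazily

    for name, path in MODELS.items():
        # Load once per model so loading time is not counted in the measurement.
        model = WhisperModel(path, device="cpu", compute_type=compute_type)
        text, seconds = transcribe_timed(model, audio_path)
        print(f"[{name}] {seconds:.2f} s: {text}")


if __name__ == "__main__":
    compare("/tmp/audio_recording.wav")
```

A single run on a short clip is noisy, so repeating the measurement a few times gives a more reliable picture of the speed difference.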
Arche151 commented 2 months ago

@sanchit-gandhi I wanted to ask whether you've had time to look at my scripts, and whether the updated training code you mentioned has been pushed yet :)