m-bain / whisperX

WhisperX: Automatic Speech Recognition with Word-level Timestamps (& Diarization)

How to use a fine-tuned segmentation model for diarization? #840

Open Arche151 opened 1 month ago

Arche151 commented 1 month ago

I have a WhisperX Python script for transcribing meetings, but the speaker diarization for German is really bad, unfortunately.

After some research, I came across the fine-tuned German segmentation model diarizers-community/speaker-segmentation-fine-tuned-callhome-deu, but I haven't figured out how to get WhisperX to use it.

Here's my Python script:

import os
import sys
import torch  
import whisperx
import ffmpeg

# Hardcoded Hugging Face token
HF_TOKEN = 'xyz'

def convert_to_wav(audio_path):
    output_path = os.path.splitext(audio_path)[0] + ".wav"
    ffmpeg.input(audio_path).output(output_path).run(quiet=True, overwrite_output=True)
    return output_path

def transcribe_audio(audio_path, num_speakers=None):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 16
    compute_type = "float16" if torch.cuda.is_available() else "int8"

    # Load WhisperX model
    print(f"Loading WhisperX model on {device}...")
    model = whisperx.load_model("flozi00/whisper-large-v3-german-ct2", device, compute_type=compute_type)

    # Load and transcribe audio
    print("Loading and transcribing audio...")
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=batch_size)

    # Load alignment model
    print("Loading alignment model...")
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

    # Load diarization model and assign speaker labels
    print("Loading diarization model...")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
    diarize_segments = diarize_model(audio, min_speakers=num_speakers, max_speakers=num_speakers)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Add speaker label if missing
    for segment in result["segments"]:
        if 'speaker' not in segment:
            segment['speaker'] = 'Unknown'

    return result["segments"]

def save_transcription(transcription, audio_path, speaker_mapping):
    output_path = os.path.splitext(audio_path)[0] + ".txt"
    with open(output_path, 'w') as f:
        for segment in transcription:
            speaker = speaker_mapping.get(segment['speaker'], segment['speaker'])
            f.write(f"{speaker}: {segment['text']}\n")
    print(f"Transcription saved to {output_path}")

def display_first_10_lines(transcription):
    for i, segment in enumerate(transcription[:10]):
        print(f"{segment['speaker']}: {segment['text']}")
    print()

def get_speaker_names(unique_speakers):
    speaker_mapping = {}
    for speaker in unique_speakers:
        name = input(f"Enter the name for {speaker}: ")
        speaker_mapping[speaker] = name
    return speaker_mapping

def main():
    audio_path = input("Enter the filename of the audio to be transcribed: ").strip().strip("'")
    if not os.path.isfile(audio_path):
        print(f"Error: The file '{audio_path}' does not exist.")
        sys.exit(1)

    num_speakers = int(input("Enter the number of speakers: "))

    if not audio_path.endswith(".wav"):
        print("Converting audio to WAV format...")
        audio_path = convert_to_wav(audio_path)

    print("Transcribing audio...")
    transcription = transcribe_audio(audio_path, num_speakers)

    print("\nFirst 10 lines of the transcription:")
    display_first_10_lines(transcription)

    unique_speakers = sorted(set(segment['speaker'] for segment in transcription))
    speaker_mapping = get_speaker_names(unique_speakers)

    save_transcription(transcription, audio_path, speaker_mapping)

if __name__ == "__main__":
    main()

I'd greatly appreciate any help!

Dream-gamer commented 4 weeks ago

This worked for me:

from diarizers import SegmentationModel

diarize_model = whisperx.DiarizationPipeline(use_auth_token="Your_hftoken", device=device)

segmentation_model = SegmentationModel().from_pretrained('diarizers-community/speaker-segmentation-fine-tuned-callhome-deu')

fine_tuned_model = segmentation_model.to_pyannote_model()

diarize_model.model._segmentation.model = fine_tuned_model.to(device)
diarize_segments = diarize_model(audio, num_speakers=2)
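
In case it helps, this is roughly how the snippet would slot into the transcribe_audio() function from your script. It's just a sketch: it assumes the diarizers package is installed (pip install diarizers), and it relies on the private _segmentation attribute, so it may break with other whisperX/pyannote versions.

import whisperx
from diarizers import SegmentationModel

def load_diarization_model(device, hf_token):
    # Standard whisperX diarization pipeline (wraps pyannote speaker diarization).
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

    # Load the fine-tuned German segmentation checkpoint and convert it into a
    # pyannote-compatible model.
    segmentation_model = SegmentationModel().from_pretrained(
        "diarizers-community/speaker-segmentation-fine-tuned-callhome-deu"
    )
    fine_tuned_model = segmentation_model.to_pyannote_model()

    # Swap the pipeline's default segmentation model for the fine-tuned one.
    diarize_model.model._segmentation.model = fine_tuned_model.to(device)
    return diarize_model

In transcribe_audio() you would then replace the whisperx.DiarizationPipeline(...) line with diarize_model = load_diarization_model(device, HF_TOKEN).
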
Arche151 commented 4 weeks ago

@Dream-gamer

I will try that out, thanks! :) Heads up, you accidentally shared your HF_Token.

Do you maybe also happen to know how I can use speaker embeddings that I extracted from reference speakers in the diarization pipeline? I asked about that in the pyannote GitHub but unfortunately didn't get a response: https://github.com/pyannote/pyannote-audio/issues/1750
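
For context, what I have in mind is roughly the sketch below: extract one embedding per reference speaker with pyannote's Inference API and match the diarized speakers to them by cosine distance afterwards. The reference file names are placeholders and I haven't verified this end to end; what I'd really like is to pass the embeddings into the diarization pipeline itself.

import numpy as np
from scipy.spatial.distance import cdist
from pyannote.audio import Model, Inference
from pyannote.core import Segment

# Embedding model from pyannote (uses the same HF token as the diarization pipeline).
embedding_model = Model.from_pretrained("pyannote/embedding", use_auth_token=HF_TOKEN)
inference = Inference(embedding_model, window="whole")

# One short, clean reference WAV per known speaker (placeholder file names).
reference_wavs = {"Alice": "alice_ref.wav", "Bob": "bob_ref.wav"}
reference_embeddings = {name: np.atleast_2d(inference(path)) for name, path in reference_wavs.items()}

# diarize_segments is the DataFrame returned by whisperx.DiarizationPipeline
# (columns include 'start', 'end' and 'speaker').
speaker_names = {}
for speaker, rows in diarize_segments.groupby("speaker"):
    row = rows.iloc[0]
    emb = np.atleast_2d(inference.crop(audio_path, Segment(row["start"], row["end"])))
    distances = {name: cdist(emb, ref, metric="cosine")[0, 0]
                 for name, ref in reference_embeddings.items()}
    speaker_names[speaker] = min(distances, key=distances.get)  # closest reference speaker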

Dream-gamer commented 4 weeks ago

Hey, thanks for the heads up lol. Do let me know if you face any errors in the above code, because I just tried it today and it worked. As for your follow-up question, I haven't really used speaker embeddings, but I'm currently working on improving my diarized transcriptions as well, so I'll play around with that and get back to you if I find anything.

slaesh commented 1 week ago

Thanks @Dream-gamer, loading the fine-tuned model works just fine!

@Arche151 I'm having the same problem, but using the fine-tuned model doesn't increase the diarization accuracy that much.. How is it going on your side? Any further steps taken to improve it? =)

Dream-gamer commented 1 week ago

Glad to know. Yeah, it didn't noticeably increase the accuracy for me either. I have been using LLMs like Gemini to parse the generated transcript and return a corrected version. You can use a prompt like: "Here is the speaker separated transcript. Some of the words in the transcript are in the wrong speaker labels. Correct them and give the corrected transcript:" This has given me much better results for my Hindi-English transcripts. Also, if that doesn't improve performance, you can use Gemini for the diarization itself in Google's AI Studio: https://aistudio.google.com/app/prompts/new_chat?pli=1 Just upload the audio file, give some context in the prompt (e.g. number of speakers and the language of the audio), and ask it to generate a speaker-separated transcript.
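
If you want to do that cleanup step programmatically rather than pasting into AI Studio, a minimal sketch with the google-generativeai package could look like the following; the API key handling, model name, and transcript path are just placeholders.

import os
import google.generativeai as genai

# Assumes `pip install google-generativeai` and an API key in the environment.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")  # example model name

# The speaker-separated transcript produced by the script above (placeholder path).
with open("meeting.txt") as f:
    transcript = f.read()

prompt = (
    "Here is the speaker separated transcript. Some of the words in the transcript "
    "are in the wrong speaker labels. Correct them and give the corrected transcript:\n\n"
    + transcript
)

response = model.generate_content(prompt)
print(response.text)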

slaesh commented 1 week ago

Thanks! :) Yeah, I also thought about doing something similar for the transcription itself. I'm doing everything offline, so Gemini is not an option for me ;D The transcription itself isn't that big an issue for me; it's more the diarization. If there are questions and someone else answers quickly with just one or two words, most of the time it isn't really separated, and both the question and the answer end up assigned to the same speaker. Weirdly, sometimes even the question gets attributed to the answerer :D