juanmc2005 / diart

A python package to build AI-powered real-time audio applications
https://diart.readthedocs.io
MIT License

Hallucinations from program #235

Closed hmehdi515 closed 4 months ago

hmehdi515 commented 4 months ago

Hi. I am trying to test this program with an audio file and can't seem to get rid of hallucinations, although the transcription works perfectly by just using

whisper_timestamped test_audio.wav --model small

The correct transcription output (.srt file):

1
00:00:00,360 --> 00:00:02,160
So, what's new, Mark?

2
00:00:02,980 --> 00:00:04,700
How is your new job going?

3
00:00:05,720 --> 00:00:08,500
To be honest, I can't complain.

4
00:00:09,660 --> 00:00:12,660
I really love the company that I am working for.

5
00:00:14,220 --> 00:00:17,660
My co-workers are all really friendly and helpful.

6
00:00:18,800 --> 00:00:20,820
They really help me feel welcome.

7
00:00:22,160 --> 00:00:24,980
It's a really energetic and fun atmosphere.

But when I run the diart_whisper program (same small model):

 So, what's my word?
 I'm...
 ...Ka-
 ...Ballemash.

 ...Ka- ...Ka- ...Ballemash.
 ...Ballemash.

I am using a server with GPUs and CUDA, so I don't believe a slow CPU is the issue.

I've tried playing around with the chunk size with no success. I'm not sure tuning the parameters is the issue either, since it works from the command line. Might something be going wrong in the diarization part?

I have attached an .mp4 of the audio file, since I can't attach a .mp3 or .wav file; you can change the extension back to .wav.

https://github.com/juanmc2005/diart/assets/170109983/78e41776-2460-42f9-97f1-08a805be2ec6

The code:

import logging
import os
import sys
import traceback
from contextlib import contextmanager

import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
import whisper_timestamped as whisper
import torch
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.sources import FileAudioSource  # Import FileAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment

def concat(chunks, collar=0.05):
    """
    Concatenate predictions and audio
    given a list of `(diarization, waveform)` pairs
    and merge contiguous single-speaker regions
    with pauses shorter than `collar` seconds.
    """
    first_annotation = chunks[0][0]
    first_waveform = chunks[0][1]
    annotation = Annotation(uri=first_annotation.uri)
    data = []
    for ann, wav in chunks:
        annotation.update(ann)
        data.append(wav.data)
    annotation = annotation.support(collar)
    window = SlidingWindow(
        first_waveform.sliding_window.duration,
        first_waveform.sliding_window.step,
        first_waveform.sliding_window.start,
    )
    data = np.concatenate(data, axis=0)
    return annotation, SlidingWindowFeature(data, window)

def colorize_transcription(transcription):
    """
    Unify a speaker-aware transcription represented as
    a list of `(speaker: int, text: str)` pairs
    into a single text colored by speakers.
    """
    colors = 2 * [
        "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
        "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use default terminal color
            result.append(text)
        else:
            result.append(f"[{colors[speaker]}]{text}")
    return "\n".join(result)

@contextmanager
def suppress_stdout():
    # Auxiliary function to suppress Whisper logs (it is quite verbose)
    # All credit goes to: https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

class WhisperTranscriber:
    def __init__(self, model="small", device=None):
        self.model = whisper.load_model(model, device=device)
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)

        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            transcription = whisper.transcribe(
                self.model,
                audio,
                # We use past transcriptions to condition the model
                initial_prompt=self._buffer,
                vad=False,
                remove_empty_words=True,
                no_speech_threshold=0.4,
                verbose=True  # to avoid progress bar
            )

        return transcription

    def identify_speakers(self, transcription, diarization, time_shift):
        """Iterate over transcription segments to assign speakers"""
        speaker_captions = []
        for segment in transcription["segments"]:

            # Crop diarization to the segment timestamps
            start = time_shift + segment["words"][0]["start"]
            end = time_shift + segment["words"][-1]["end"]
            dia = diarization.crop(Segment(start, end))

            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment["text"])
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment["text"])
            else:
                # Multiple speakers, select the one that speaks the most
                max_speaker = int(np.argmax([
                    dia.label_duration(spk) for spk in speakers
                ]))
                caption = (max_speaker, segment["text"])
            speaker_captions.append(caption)

        return speaker_captions

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        transcription = self.transcribe(waveform)
        # Update transcription buffer
        self._buffer += transcription["text"]
        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions = self.identify_speakers(transcription, diarization, time_shift)
        return speaker_transcriptions

# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

# If you have a GPU, you can also set device=torch.device("cuda")
config = SpeakerDiarizationConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57,
    device=torch.device("cuda")
)

print("-Cuda is available=" , torch.cuda.is_available())
dia = SpeakerDiarization(config)
source = FileAudioSource("test_audio.wav", sample_rate=44100)

# If you have a GPU, you can also set device="cuda"
asr = WhisperTranscriber(model="small", device="cuda")

# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Apply models in batches for better efficiency
batch_size = int(transcription_duration // config.step)

# Chain of operations to apply on the stream of file audio
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color transcriptions according to the speaker
    # The output is plain text with color references for rich
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda _: traceback.print_exc()  # print stacktrace if error
)

print("Listening...")
source.read()

Any help is appreciated. Thanks

CsehAbel commented 4 months ago

As the author has written previously, play around with the parameters tau_active=0.5, rho_update=0.1, delta_new=0.57. You can find more information about these parameters in the readme and in the issues:

"Increase the value of delta_new, which is essentially the distance threshold (between speaker embedding and closest cluster centroid) to detect new speakers." https://github.com/juanmc2005/diart/discussions/171

Find more about parameters here: https://diart.readthedocs.io/en/latest/autoapi/diart/blocks/clustering/index.html https://diart.readthedocs.io/en/latest/autoapi/diart/blocks/diarization/index.html

TL;DR: Please try lowering the tau_active threshold.
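
A minimal sketch of what that could look like against the config in the script above (the values here are only a starting point, not recommended settings):

import torch
from diart import SpeakerDiarization, SpeakerDiarizationConfig

# Same config as in the script above, but with a lower tau_active
# (more permissive voice activity detection) and a higher delta_new
# (less eager to create new speaker clusters). Values are illustrative only.
config = SpeakerDiarizationConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.4,   # lowered from 0.5
    rho_update=0.1,
    delta_new=0.7,    # raised from 0.57
    device=torch.device("cuda"),
)
dia = SpeakerDiarization(config)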

hmehdi515 commented 4 months ago

Update -- decreasing the sample rate of the audio file gave more accurate transcriptions:

 So, what's new, Mark? How is your new job going?
 To be honest...
 I can't complain.
 I really can't complain.
 I really love the company.
 That I am working for.
 My coworker...
 ...are all really...
 friendly and helpful.
 They really help...
 me feel welcome.

I've decided 16000 Hz is the most accurate; now I'll play with the parameters.
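
A minimal sketch of one way to downsample the file to 16 kHz beforehand (this uses torchaudio, which is not part of the script above; file names are placeholders):

import torchaudio

# Load the original file and resample it to the 16 kHz expected by the models
waveform, orig_sr = torchaudio.load("test_audio.wav")
resampled = torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=16000)
torchaudio.save("test_audio_16k.wav", resampled, 16000)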

juanmc2005 commented 4 months ago

Hi @hmehdi515, the sample rate supported by both Whisper and diart is 16 kHz, so make sure all your audio is loaded that way or dynamically resampled by diart (check out diart.utils.Resample).
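
For example, something along these lines in your script, so the file is read at the rate the models expect (this assumes config.sample_rate resolves to 16000 here and that FileAudioSource loads the file at the requested rate):

from diart.sources import FileAudioSource

# Read the file at the pipeline's sample rate (assumed 16 kHz) instead of 44100
source = FileAudioSource("test_audio.wav", sample_rate=config.sample_rate)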

I'm closing the issue as it looks resolved.