juanmc2005 / diart

A python package to build AI-powered real-time audio applications
https://diart.readthedocs.io
MIT License

Words repeated in whisper transcription with initial_prompt #165

Closed: ahmedmoorsy closed this issue 5 months ago

ahmedmoorsy commented 1 year ago

Hello all,

Thank you for this great work! I updated this code to use faster-whisper, and I'm facing a repeated-words issue when I pass the initial_prompt parameter to the transcription method. It happens when I finish speaking on certain words, such as "Okay".

The issue:

  Okay. Okay. Okay. Okay. Okay.
 Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.
 Yeah.
 Okay. Okay. Okay. Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.
 Okay. Okay. Okay. Okay. Okay.

The code:

import logging
import os
import sys
import traceback
from contextlib import contextmanager

import torch
import whisper
import diart.operators as dops
import numpy as np
import rich
import rx.operators as ops
# import whisper_timestamped as whisper
from faster_whisper import WhisperModel
from diart import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import WebSocketAudioSource
from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment

def concat(chunks, collar=0.05):
    """
    Concatenate predictions and audio
    given a list of `(diarization, waveform)` pairs
    and merge contiguous single-speaker regions
    with pauses shorter than `collar` seconds.
    """
    first_annotation = chunks[0][0]
    first_waveform = chunks[0][1]
    annotation = Annotation(uri=first_annotation.uri)
    data = []
    for ann, wav in chunks:
        annotation.update(ann)
        data.append(wav.data)
    annotation = annotation.support(collar)
    window = SlidingWindow(
        first_waveform.sliding_window.duration,
        first_waveform.sliding_window.step,
        first_waveform.sliding_window.start,
    )
    data = np.concatenate(data, axis=0)
    return annotation, SlidingWindowFeature(data, window)

def colorize_transcription(transcription):
    """
    Unify a speaker-aware transcription represented as
    a list of `(speaker: int, text: str)` pairs
    into a single text colored by speakers.
    """
    colors = 2 * [
        "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
        "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
    ]
    result = []
    for speaker, text in transcription:
        if speaker == -1:
            # No speaker found for this text, use default terminal color
            result.append(text)
        else:
            result.append(f"[{colors[speaker]}]{text}")
    return "\n".join(result)

@contextmanager
def suppress_stdout():
    # Auxiliary function to suppress Whisper logs (it is quite verbose)
    # All credit goes to: https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

class WhisperTranscriber:
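    def __init__(self, model="small", device="cuda", compute_type="float16"):
        # Use the device/compute_type arguments instead of hardcoding them
        # (e.g. device="cpu", compute_type="int8" on a machine without a GPU)
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        #self.model = whisper.load_model(model, device=device)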
        self._buffer = ""

    def transcribe(self, waveform):
        """Transcribe audio using Whisper"""
        # Pad/trim audio to fit 30 seconds as required by Whisper
        audio = waveform.data.astype("float32").reshape(-1)
        audio = whisper.pad_or_trim(audio)

        # Transcribe the given audio while suppressing logs
        with suppress_stdout():
            segments, _ = self.model.transcribe(
                audio=audio,
                language='en',
                # We use past transcriptions to condition the model
                #initial_prompt=self._buffer,
                word_timestamps=True
            )

        return segments

    def identify_speakers(self, segments, diarization, time_shift):
        """Iterate over transcription segments to assign speakers"""
        speaker_captions = []
        text = ""
        for segment in segments:
            text += segment.text
            # Crop diarization to the segment timestamps
            start = time_shift + segment.words[0].start
            end = time_shift + segment.words[-1].end
            dia = diarization.crop(Segment(start, end))

            # Assign a speaker to the segment based on diarization
            speakers = dia.labels()
            num_speakers = len(speakers)
            if num_speakers == 0:
                # No speakers were detected
                caption = (-1, segment.text)
            elif num_speakers == 1:
                # Only one speaker is active in this segment
                spk_id = int(speakers[0].split("speaker")[1])
                caption = (spk_id, segment.text)
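            else:
                # Multiple speakers, select the one that speaks the most
                # (argmax gives a position in `speakers`, so map it back to the label's id)
                durations = [dia.label_duration(spk) for spk in speakers]
                main_speaker = speakers[int(np.argmax(durations))]
                spk_id = int(main_speaker.split("speaker")[1])
                caption = (spk_id, segment.text)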
            speaker_captions.append(caption)

        return speaker_captions, text

    def __call__(self, diarization, waveform):
        # Step 1: Transcribe
        segments = self.transcribe(waveform)

        # The audio may not be the beginning of the conversation
        time_shift = waveform.sliding_window.start
        # Step 2: Assign speakers
        speaker_transcriptions, text = self.identify_speakers(segments, diarization, time_shift)

        # Update transcription buffer
        self._buffer += text
        return speaker_transcriptions

# Suppress whisper-timestamped warnings for a clean output
logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)

# If you have a GPU, you can also set device=torch.device("cuda")
config = PipelineConfig(
    duration=5,
    step=0.5,
    latency="min",
    tau_active=0.5,
    rho_update=0.1,
    delta_new=0.57,
    device=torch.device("cpu")
)
dia = OnlineSpeakerDiarization(config)
source = WebSocketAudioSource(config.sample_rate, "0.0.0.0", 8081)

# If you have a GPU, you can also set device="cuda"
asr = WhisperTranscriber(model="base")

# Split the stream into 2s chunks for transcription
transcription_duration = 2
# Apply models in batches for better efficiency
batch_size = int(transcription_duration // config.step)

# Chain of operations to apply on the stream of microphone audio
source.stream.pipe(
    # Format audio stream to sliding windows of 5s with a step of 500ms
    dops.rearrange_audio_stream(
        config.duration, config.step, config.sample_rate
    ),
    # Wait until a batch is full
    # The output is a list of audio chunks
    ops.buffer_with_count(count=batch_size),
    # Obtain diarization prediction
    # The output is a list of pairs `(diarization, audio chunk)`
    ops.map(dia),
    # Concatenate 500ms predictions/chunks to form a single 2s chunk
    ops.map(concat),
    # Ignore this chunk if it does not contain speech
    ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
    # Obtain speaker-aware transcriptions
    # The output is a list of pairs `(speaker: int, caption: str)`
    ops.starmap(asr),
    # Color transcriptions according to the speaker
    # The output is plain text with color references for rich
    ops.map(colorize_transcription),
).subscribe(
    on_next=rich.print,  # print colored text
    on_error=lambda e: traceback.print_exception(type(e), e, e.__traceback__)  # print stacktrace if error
)

print("Listening...")
source.read()

Any help?

brian-j-connolly-aero commented 1 year ago

Great code! My guess is that the CPU is too slow. I ran into the same issue with whisper-realtime before switching over to CUDA. It works fine once I switched to CUDA and used my mic as the audio source.
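To check whether the CPU really is the bottleneck, one can compare how long a transcription call takes with how much audio it covers (the real-time factor, RTF). Below is a minimal sketch; `transcribe_fn` is a hypothetical stand-in for the actual ASR call, and 16 kHz is an assumed sample rate. If the RTF for a 2 s chunk is above 1, chunks arrive faster than they are transcribed and the stream falls behind.

```python
import time
from typing import Callable, Sequence


def real_time_factor(transcribe_fn: Callable[[Sequence[float]], object],
                     audio: Sequence[float],
                     sample_rate: int = 16000) -> float:
    """Return processing time divided by audio duration.

    A value above 1.0 means transcription is slower than real time, so a
    streaming pipeline will fall behind and audio will pile up.
    """
    audio_seconds = len(audio) / sample_rate
    if audio_seconds <= 0:
        raise ValueError("audio must contain at least one sample")
    start = time.perf_counter()
    transcribe_fn(audio)  # hypothetical stand-in for the real ASR call
    elapsed = time.perf_counter() - start
    return elapsed / audio_seconds


if __name__ == "__main__":
    two_seconds = [0.0] * (2 * 16000)
    # Dummy "model" that takes ~0.1 s, i.e. comfortably faster than real time
    rtf = real_time_factor(lambda a: time.sleep(0.1), two_seconds)
    print(f"RTF = {rtf:.2f}", "(keeps up)" if rtf < 1 else "(falls behind)")
```

Running this with the real model on the target machine gives a quick yes/no answer before changing the pipeline.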

shanky100 commented 1 year ago

Is there any solution for this problem when running on CPU?
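On CPU, faster-whisper is usually configured differently than on a GPU. The sketch below picks `device` and `compute_type` for `WhisperModel` based on whether CUDA is available; the choice of float16 on GPU and int8 quantization on CPU is an assumption (a common faster-whisper setup), not something confirmed in this thread, and `whisper_model_kwargs` is a hypothetical helper.

```python
def whisper_model_kwargs(cuda_available: bool) -> dict:
    """Choose faster-whisper `WhisperModel` settings for the available hardware.

    Assumption: float16 on a CUDA GPU and int8 quantization on CPU. These are
    common faster-whisper choices; benchmark them on your own machine.
    """
    if cuda_available:
        return {"device": "cuda", "compute_type": "float16"}
    # float16 is aimed at GPUs; int8 keeps CPU inference lighter
    return {"device": "cpu", "compute_type": "int8"}


if __name__ == "__main__":
    print(whisper_model_kwargs(cuda_available=False))
```

The result can be unpacked into the model constructor, e.g. `WhisperModel("base", **whisper_model_kwargs(torch.cuda.is_available()))`. Even with int8, a small model such as "base" may still be too slow for 2 s chunks on some CPUs, so measuring the real-time factor is worthwhile.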

shanky100 commented 1 year ago

Hi, I tried the code you provided, but it doesn't print anything on the terminal other than "Listening...". I also think my mic is not being enabled when I run the code.

Is there a possible solution for this?

shanky100 commented 1 year ago

Hi, is there any solution for this issue? I also get repeated results after a few seconds of recording. I am using "CPU" as the device.

juanmc2005 commented 1 year ago

Hello, the issue with repeating words could have several causes:

shanky100 commented 1 year ago

> Hello, the issue with repeating words could have several causes:
>
>   • The CPU is taking too long, so reading from the microphone is being interrupted by Whisper, as pointed out by @brian-j-connolly-aero. If you don't have a GPU, this could be solved by using an optimized Whisper implementation (whisper.cpp, faster-whisper, etc.) and/or running Whisper in a separate process.
>   • Whisper hallucinates because of short chunks, noise, music, etc. You could try tuning the decoding configuration (see, for example, the compression ratio threshold). Another option is to buffer audio to give Whisper more context and then merge/update the transcriptions. This should reduce the chances of weird behavior.
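The compression ratio threshold mentioned above is a simple heuristic: Whisper compresses each decoded segment's text with zlib and, if raw size divided by compressed size exceeds `compression_ratio_threshold` (2.4 by default in both openai-whisper and faster-whisper), treats the decoding as failed and retries at a higher temperature. The sketch below re-implements only the ratio, to show why output like "Okay. Okay. Okay." trips it; `looks_repetitive` is a hypothetical helper.

```python
import zlib


def compression_ratio(text: str) -> float:
    """Ratio of raw UTF-8 size to zlib-compressed size.

    Modeled on the heuristic Whisper uses: highly repetitive output
    compresses very well, so its ratio is high.
    """
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def looks_repetitive(text: str, threshold: float = 2.4) -> bool:
    """Flag text whose compression ratio exceeds the threshold.

    2.4 is Whisper's default compression_ratio_threshold; treat it as a
    starting point to tune, not a guarantee.
    """
    return compression_ratio(text) > threshold


if __name__ == "__main__":
    loop = " Okay." * 30
    normal = " I think we should ship the release on Friday."
    print(round(compression_ratio(loop), 1), looks_repetitive(loop))
    print(round(compression_ratio(normal), 1), looks_repetitive(normal))
```

Note that a short chunk containing only a handful of repeats may stay below the threshold, which is one reason longer chunks help.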

I am currently passing an audio file directly instead of using the microphone as the source. Could this explain the strange behavior of repeated words?

juanmc2005 commented 1 year ago

@shanky100 most probably not. Make sure to check whether Whisper hallucinates when transcribing the entire file at once (instead of streaming it). It's possible that the chunks are too short; in that case, I suggest increasing the ASR chunk size or buffering the audio.
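A minimal sketch of the buffering idea, assuming mono float audio at 16 kHz: each new chunk is appended to a rolling window of at most `max_seconds`, and Whisper is run on the whole window instead of on the 2 s chunk alone. `RollingAudioBuffer` is a hypothetical helper, not part of diart, and the 10 s default is arbitrary.

```python
import numpy as np


class RollingAudioBuffer:
    """Keep the most recent `max_seconds` of mono audio for re-transcription."""

    def __init__(self, sample_rate: int = 16000, max_seconds: float = 10.0):
        self.sample_rate = sample_rate
        self.max_samples = int(max_seconds * sample_rate)
        self._audio = np.zeros(0, dtype=np.float32)
        # Stream time (s) of the first buffered sample; assumes the stream starts at 0
        self.start_time = 0.0

    def push(self, chunk: np.ndarray) -> np.ndarray:
        """Append a chunk, drop the oldest samples beyond the limit, return the buffer."""
        chunk = np.asarray(chunk, dtype=np.float32).reshape(-1)
        self._audio = np.concatenate([self._audio, chunk])
        overflow = len(self._audio) - self.max_samples
        if overflow > 0:
            self._audio = self._audio[overflow:]
            self.start_time += overflow / self.sample_rate
        return self._audio

    @property
    def duration(self) -> float:
        """Length of the buffered audio in seconds."""
        return len(self._audio) / self.sample_rate


if __name__ == "__main__":
    buf = RollingAudioBuffer(max_seconds=5.0)
    for _ in range(3):
        window = buf.push(np.zeros(2 * 16000))  # three 2 s chunks
    print(buf.duration, buf.start_time)
```

`start_time` is the offset to use as the time shift when aligning word timestamps with the diarization. Because consecutive windows overlap, the same words are transcribed more than once, so the caller still has to merge or replace earlier captions; this sketch does not do that part.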

You can also try removing the text conditioning. If Whisper hallucinates at the beginning and you keep conditioning it on the hallucination, you may be seeing a snowballing effect.
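If conditioning is kept, one hedge against the snowballing effect is to bound the prompt and refuse to feed obviously repetitive text back into it. The sketch below is hypothetical: `PromptBuffer` is not a diart or faster-whisper class, the 200-character limit is arbitrary, and the repetition check reuses the zlib compression-ratio heuristic Whisper applies to its own output (2.4 being Whisper's default `compression_ratio_threshold`).

```python
import zlib
from typing import Optional


def compression_ratio(text: str) -> float:
    """Raw UTF-8 size divided by zlib-compressed size (Whisper-style heuristic)."""
    data = text.encode("utf-8")
    return len(data) / len(zlib.compress(data))


class PromptBuffer:
    """Bounded text history to pass as Whisper's initial_prompt.

    Keeps only the last `max_chars` characters of past transcriptions and
    skips text that looks like a repetition loop, so one hallucination is
    less likely to be fed back into every later chunk.
    """

    def __init__(self, max_chars: int = 200, max_ratio: float = 2.4):
        self.max_chars = max_chars
        self.max_ratio = max_ratio
        self._text = ""

    def add(self, text: str) -> bool:
        """Append text unless it is blank or looks repetitive; return True if kept."""
        if not text.strip() or compression_ratio(text) > self.max_ratio:
            return False
        self._text = (self._text + text)[-self.max_chars:]
        return True

    @property
    def prompt(self) -> Optional[str]:
        # None means no initial prompt is passed to the model
        return self._text or None
```

In `WhisperTranscriber`, this would mean passing `initial_prompt=self._prompt.prompt` and calling `self._prompt.add(text)` in place of `self._buffer += text`. To remove conditioning altogether, leave `initial_prompt` unset; faster-whisper's `transcribe()` also accepts `condition_on_previous_text=False`, which stops it from using the previous window's output as the prompt for the next window within a single call.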