
Whisper no_speech_threshold not applied when chunking input #29595

Closed. stri8ed closed this 1 month ago.

stri8ed commented 7 months ago

Feature request

The Whisper pipeline accepts a chunk_length_s parameter, which chunks the input so it can be used for batch inference. There is also a no_speech_threshold parameter, which can be used to filter out silence and helps reduce hallucinations. The problem is that no_speech_threshold is only applied when the input is considered "long-form", and with chunked input every chunk is treated as short-form even though the full input is long. This means no_speech_threshold can't be applied to chunked input.

When attempting this, it gives this error for each batch: Audio input consists of only 3000. Short-form transcription is activated. no_speech_threshold is set to 0.2, but will be ignored.

It should be possible to keep no_speech_threshold enabled for long, chunked inputs.
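
For reference, a minimal sketch of the kind of call this refers to (model name, chunk settings, and audio path are illustrative, not from the original report):

from transformers import pipeline

# Chunked long-form transcription; chunk settings and audio path are placeholders.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    batch_size=8,
)

# no_speech_threshold is forwarded to model.generate, but each 30 s chunk is
# treated as short-form, so the threshold is ignored and the warning above is
# logged for every batch.
result = pipe(
    "long_audio.wav",
    return_timestamps=True,
    generate_kwargs={"no_speech_threshold": 0.2},
)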

Motivation

Chunking brings large performance improvements for long-form transcription, but this benefit is negated if there is no way to suppress the silence segments, which most long-form audio will no doubt contain.

Your contribution

I could implement a change that removes the is_shortform check when the input is chunked, though I'm unsure whether, and why, this would break things.

amyeroberts commented 7 months ago

cc @sanchit-gandhi @ylacombe

stri8ed commented 7 months ago

I was able to work around this by using the VAD implementation from faster-whisper:

from faster_whisper.vad import get_speech_timestamps, VadOptions, collect_chunks, SpeechTimestampsMap

# load_audio, pipe, audio_file and sampling_rate are defined elsewhere
# (an audio decoder, the transformers ASR pipeline, the input path and its sample rate).
audio_data = load_audio(audio_file)

# Detect speech regions and drop the silent spans before transcription.
speech_chunks = get_speech_timestamps(audio_data, VadOptions())
audio_without_silence = collect_chunks(audio_data, speech_chunks)

prediction = pipe(audio_without_silence, ...)

# Map timestamps from the silence-stripped audio back to the original timeline.
chunks = prediction["chunks"]
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
segments = []
for chunk in chunks:
    start, end = chunk["timestamp"]
    segments.append({
        "start": ts_map.get_original_time(start),
        "end": ts_map.get_original_time(end),
        "text": chunk["text"],
    })

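With this approach the silence is removed before the audio ever reaches the pipeline, so the returned chunk timestamps refer to the silence-stripped signal; SpeechTimestampsMap is what projects them back onto the original recording's timeline.
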
Kimahriman commented 7 months ago

Related to https://github.com/huggingface/transformers/issues/29508; not sure why there are different implementations for long and short audio. I'm playing around with combining them.

amyeroberts commented 6 months ago

Gentle ping @sanchit-gandhi @ylacombe

amyeroberts commented 5 months ago

Another ping @ylacombe @sanchit-gandhi @kashif

sanchit-gandhi commented 5 months ago

Indeed, as @Kimahriman mentioned in #29508, there should be no distinction between the short- and long-form algorithms. This issue will be fixed along with #29508, by merging the short- and long-form generation logic.

amyeroberts commented 2 months ago

Is this fixed now that #29508 is merged, @kamilakesbi @sanchit-gandhi?

amyeroberts commented 1 month ago

cc @ylacombe

ylacombe commented 1 month ago

Hey @amyeroberts, this is indeed fixed since #29508.

@stri8ed, note that you also have to specify logprob_threshold (a float) and temperature (a list of floats), as indicated in the docs!

e.g.:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import torch

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = librispeech_dummy[0]["audio"]

input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features

# no_speech_threshold only takes effect when logprob_threshold and a
# temperature fallback schedule are also provided.
outputs = model.generate(
    input_features, output_scores=True, return_dict_in_generate=True, max_new_tokens=128,
    no_speech_threshold=0.2, logprob_threshold=-1.0, temperature=(0.2, 0.8)
)

# Transition scores are used below to recover the probability of the predicted language token.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)

pred_text = processor.batch_decode(outputs.sequences, skip_special_tokens=True)
pred_language = processor.batch_decode(outputs.sequences[:, 1:2], skip_special_tokens=False)
lang_prob = torch.exp(transition_scores[:, 0])

print(pred_text)
print(pred_language)
print(lang_prob)
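
For the chunked pipeline case from the original report, the same generation parameters should be forwardable through generate_kwargs on a transformers version that includes #29508; a minimal sketch (model name, chunk settings, and audio path are placeholders):

from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
    batch_size=8,
)

result = pipe(
    "long_audio.wav",
    return_timestamps=True,
    # Per the note above, no_speech_threshold needs logprob_threshold and a
    # temperature fallback schedule to take effect.
    generate_kwargs={
        "no_speech_threshold": 0.2,
        "logprob_threshold": -1.0,
        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    },
)
print(result["text"])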