huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Incorrect Whisper long-form decoding timestamps #31942

Closed: Robinysh closed this issue 4 months ago

Robinysh commented 4 months ago

Who can help?

@Narsil @sanchit-gandhi

Reproduction

import numpy as np
import json
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda"
torch_dtype = torch.bfloat16
# model_id = "openai/whisper-large-v3"
model_id = "distil-whisper/distil-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=None,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]["audio"]
# Tile the ~6 s clip 10x so the input exceeds 30 s and triggers long-form decoding
sample = np.concatenate([sample["array"]] * 10)

results = pipe(
    sample,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
    },
)
print(json.dumps(results, indent=4))

Output

{
    "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome His Gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel. Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel.",
    "chunks": [
        {
            "timestamp": [
                0.0,
                6.5
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                6.5,
                12.5
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                12.5,
                18.24
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome His Gospel."
        },
        {
            "timestamp": [
                18.24,
                24.0
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                24.0,
                29.84
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                0.0,
                4.7
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                5.76,
                10.54
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                11.6,
                16.4
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                17.46,
                22.26
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        },
        {
            "timestamp": [
                23.12,
                28.12
            ],
            "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
        }
    ]
}

Expected behavior

Currently the timestamps reset to zero after 30s of audio. I expect the timestamps to increase monotonically across the whole input.
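
For illustration, the sixth chunk should pick up where the fifth ends, i.e. something like this (hypothetical values, extrapolated from the chunk spacing above):

{
    "timestamp": [
        29.84,
        34.54
    ],
    "text": " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his gospel."
}

instead of the reset to [0.0, 4.7] shown above.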

Robinysh commented 4 months ago

Additionally, the bug does not occur when I pass chunk_length_s=30 to pipe(), i.e.

results = pipe(
    sample,
    chunk_length_s=30,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
    },
)

However, this workaround is not applicable in my case because I also want to supply the compression_ratio_threshold argument via generate_kwargs, and that is not supported with short-form transcription.
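
For reference, this is roughly the combination I would need, which short-form (chunked) transcription currently rejects. The threshold values here are illustrative; in transformers, compression_ratio_threshold is used together with a temperature fallback tuple and logprob_threshold:

results = pipe(
    sample,
    chunk_length_s=30,
    return_timestamps=True,
    generate_kwargs={
        "language": "english",
        # Illustrative fallback-decoding kwargs; only supported for
        # long-form generation at the time of writing.
        "compression_ratio_threshold": 1.35,
        "logprob_threshold": -1.0,
        "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    },
)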

amyeroberts commented 4 months ago

cc @kamilakesbi

kamilakesbi commented 4 months ago

Thanks for opening this issue @Robinysh!

This is indeed a problem; I'll open an issue to solve it!

Note that we're working on unifying short-form and long-form generation in PR #30984. Once it's merged, you should be able to use compression_ratio_threshold with short-form transcription :)

kamilakesbi commented 4 months ago

Hi @Robinysh,

You could use this workaround before we properly integrate the solution in Transformers:

import numpy as np
import json
from transformers import AutoProcessor, WhisperForConditionalGeneration
import torch
from datasets import load_dataset

device = "cuda"
torch_dtype = torch.float16

processor = AutoProcessor.from_pretrained("distil-whisper/distil-large-v3")
model = WhisperForConditionalGeneration.from_pretrained(
    "distil-whisper/distil-large-v3", torch_dtype=torch_dtype
)
model = model.to(device)

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]["audio"]
# Tile the ~6 s clip 10x so the input exceeds 30 s and triggers long-form decoding
sample = np.concatenate([sample["array"]] * 10)

# truncation=False keeps the full >30 s input instead of clipping it to one window
inputs = processor(sample, return_tensors="pt", truncation=False, sampling_rate=16_000)
inputs = inputs.to(device, torch_dtype)

# return_segments=True exposes the per-segment timestamps tracked during long-form generation
output = model.generate(**inputs, return_timestamps=True, return_segments=True)

result = processor.batch_decode(
    output["sequences"], skip_special_tokens=True, output_offsets=True
)

# Overwrite the (incorrect) decoded offsets with the correct segment boundaries
for i in range(len(result[0]["offsets"])):
    result[0]["offsets"][i]["timestamp"] = (
        output["segments"][0][i]["start"].item(),
        output["segments"][0][i]["end"].item(),
    )

print(json.dumps(result, indent=4))

Explanation:

When performing long-form generation with Whisper, the correct utterance-level timestamps are returned by generate when we specify return_segments=True and return_timestamps=True.
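
For reference, a minimal sketch of reading those segment timestamps directly, assuming the output of the snippet above (each entry of output["segments"] holds the segments for one batch item):

# Each segment is a dict that includes "start" and "end" tensors, among other keys
for segment in output["segments"][0]:
    print(segment["start"].item(), segment["end"].item())  # monotonically increasing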

The problem arises at the decoding level: batch_decode currently doesn't support the segments output of long-form generation. When output_offsets is specified, it returns the incorrect timestamps you obtained above.

One simple solution is to replace the decoded timestamps with the ones stored in output["segments"], which I do with these lines:

for i in range(len(result[0]["offsets"])):
    result[0]["offsets"][i]["timestamp"] = (
        output["segments"][0][i]["start"].item(),
        output["segments"][0][i]["end"].item(),
    )
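
Until this is integrated, a small helper can package the whole workaround. A minimal sketch (the helper name and the pipeline-style return shape are my own, not a transformers API):

def transcribe_long_form(model, processor, audio, device, dtype):
    # Hypothetical helper: long-form transcription with corrected chunk timestamps.
    inputs = processor(audio, return_tensors="pt", truncation=False, sampling_rate=16_000)
    inputs = inputs.to(device, dtype)
    output = model.generate(**inputs, return_timestamps=True, return_segments=True)
    decoded = processor.batch_decode(
        output["sequences"], skip_special_tokens=True, output_offsets=True
    )
    # Pair each decoded offset with the correct segment boundaries from generate
    chunks = [
        {"timestamp": (seg["start"].item(), seg["end"].item()), "text": off["text"]}
        for off, seg in zip(decoded[0]["offsets"], output["segments"][0])
    ]
    return {"text": decoded[0]["text"], "chunks": chunks}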

cc @sanchit-gandhi @ylacombe (We should integrate this properly in batch_decode and also handle it in the automatic speech recognition pipeline; I'll open a PR for that :) )