huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Enabling timestamps changes text/reduces accuracy #30815

Closed jaggzh closed 5 months ago

jaggzh commented 5 months ago

System Info

Who can help?

@sanchit-gandhi @ArthurZucker @younesbelkada

Information

Tasks

Reproduction

  1. With a fine-tuned whisper model
  2. Enable return_timestamps=True in the generate() call
  3. Compare predicted results against a generate() without return_timestamps=True
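
A minimal sketch of that comparison (the checkpoint path and audio file below are placeholders, not the actual fine-tuned model or recording from this report):

    # Minimal, hypothetical reproduction sketch: paths are placeholders.
    import librosa
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    model_dir = "path/to/fine-tuned-whisper"   # placeholder
    audio_path = "path/to/clip.flac"           # placeholder

    processor = WhisperProcessor.from_pretrained(model_dir)
    model = WhisperForConditionalGeneration.from_pretrained(model_dir)

    audio, sr = librosa.load(audio_path, sr=16000)
    input_features = processor(audio, sampling_rate=sr, return_tensors="pt").input_features

    ids_plain = model.generate(input_features, language="en")
    ids_ts = model.generate(input_features, language="en", return_timestamps=True)

    # The decoded text differs between the two calls:
    print(processor.batch_decode(ids_plain, skip_special_tokens=True)[0])
    print(processor.batch_decode(ids_ts, skip_special_tokens=True)[0])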

Expected behavior

The text with and without timestamps "should" match, no? But enabling timestamps somehow interferes, changing the text and, in this case, decreasing its accuracy.

This is a fine-tuned model, with a complex voice (a patient whispering, breathing on a ventilator), and so far with insufficient data for better training. My point is that I believe the model is therefore more susceptible to influences that degrade its recognition. However, my main questions are:

  1. How does the inclusion of timestamps, via generate(..., return_timestamps=True), end up affecting the whole process?
  2. Is there anything that can be done to keep the 'original' (non-timestamp-based) accuracy?

My code (it's a bit of a mess as I experiment):

    for predwav in predwavs:
        aa,sr=librosa.load(predwav, sr=16000)
        sample=aa.astype(np.float64)

        input_features = processor(sample, sampling_rate=sr, return_tensors="pt").input_features 
        # generate token ids

        print("Generating...")
        # Raises error. don't use:
        # pids = model.generate(input_features, return_timestamps=True, return_token_timestamps=True, language='en')
        #    ".../transformers/tokenization_utils.py", line 976, in convert_ids_to_tokens
        #    index = int(index)
        #            ^^^^^^^^^^
        #    ValueError: invalid literal for int() with base 10: 's'

        # Allows timestamps but reduces transcription accuracy:
        pids = model.generate(input_features, return_timestamps=True, language='en')

        # Highest accuracy is without timestamps:
        # pids = model.generate(input_features, language='en')
        print("/Generating.")
        # decode token ids to text
        # print("batch_decode()")
        # transcription = processor.batch_decode(pids, skip_special_tokens=False)
        # print("/batch_decode")
        # print(f"Transcription: {transcription}")
        print("Timestamp info:")
        for pidi, pid in enumerate(pids):
            # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
            # decode text plus segment offsets (note: the kwarg is output_offsets)
            pdict = processor.tokenizer.decode(pid, output_offsets=True)
            print(f"Predicted id [{pidi}] text: {pdict['text']}")
            print(f"Predicted id [{pidi}] offsets: {pdict['offsets']}")
        import ipdb; ipdb.set_trace(context=16); pass
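
(A note on the commented-out return_token_timestamps error above: judging from the traceback, this is likely because return_token_timestamps=True makes generate() return a dict-like output with sequences and token_timestamps fields rather than a plain tensor of ids, so decoding the output directly iterates over its keys. If so, something like the following sketch would unpack it -- treat the field names as an assumption to verify against your transformers version.)

    # Hypothetical fix sketch for the return_token_timestamps path; the field
    # names ("sequences", "token_timestamps") are assumptions to verify.
    out = model.generate(
        input_features,
        return_timestamps=True,
        return_token_timestamps=True,
        language="en",
    )
    token_ids = out["sequences"]            # decode these, not `out` itself
    token_times = out["token_timestamps"]   # per-token times, in seconds
    print(processor.batch_decode(token_ids, skip_special_tokens=True)[0])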

With generate()'s return_timestamps=True:

Predicted id [0] text: <|startoftranscript|><|en|><|transcribe|> There is a time... ...of a subconscious development. It don't work. Bureau work. The branch, the branch.<|endoftext|>

Predicted id [0] offsets: [{'text': ' There is a time...', 'timestamp': (0.0, 2.6)}, {'text': ' ...of a subconscious development.', 'timestamp': (14.6, 17.6)}, {'text': " It don't work.", 'timestamp': (20.6, 22.6)}, {'text': ' Bureau work.', 'timestamp': (23.400000000000002, 24.400000000000002)}, {'text': ' The branch, the branch.', 'timestamp': (25.6, 27.6)}]

Without generate()'s return_timestamps=True:

Predicted id [0] text: <|startoftranscript|><|en|><|transcribe|><|notimestamps|> there is it time... what is that chin? round one? you know what? the brown strap is... the brown strap is...<|endoftext|>

Predicted id [0] offsets: []
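
To quantify the difference rather than eyeball it, a word error rate comparison can be computed, e.g. with the jiwer package. The reference transcript below is a hypothetical placeholder; the actual ground truth isn't included in this report:

    # Hedged sketch: WER comparison with the `jiwer` package; the reference
    # transcript is a made-up placeholder, not the real ground truth.
    import jiwer

    reference = "there is a time ..."  # hypothetical ground-truth transcript

    hyp_with_ts = ("There is a time... ...of a subconscious development. "
                   "It don't work. Bureau work. The branch, the branch.")
    hyp_without_ts = ("there is it time... what is that chin? round one? you know what? "
                      "the brown strap is... the brown strap is...")

    print("WER with timestamps:   ", jiwer.wer(reference, hyp_with_ts))
    print("WER without timestamps:", jiwer.wer(reference, hyp_without_ts))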

Full code below. (Please don't look at it unless you have to!)

#!/usr/bin/env python3
import os
# This is the directory created by run_speech_recognition_seq2seq.py
whdir_def="..../voice-training-dataset-create/whisper-custom-en"
# whdir_def="whisper-custom-en/checkpoint-1100"

def get_last_checkpoint_dir(dstr):
    # looks in dstr for checkpoint-* directories, picking latest mtime and returning its full rel path
    latest_checkpoint_dir = None
    latest_mtime = -1
    for item in os.listdir(dstr):
        item_path = os.path.join(dstr, item)
        if os.path.isdir(item_path) and item.startswith('checkpoint-'):
            mtime = os.path.getmtime(item_path)
            if mtime > latest_mtime:
                latest_mtime = mtime
                latest_checkpoint_dir = item_path
    return latest_checkpoint_dir

# whdir_def=get_last_checkpoint_dir("whisper-custom-en")
print(f"Whisper dir! {whdir_def}")

predwavs=[
    '..../patient--2024-04-27_06-32-part1.flac',
   ]

use_test_dataset_hf = False
use_test_dataset_moz = False

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from transformers import HfArgumentParser
import librosa
import numpy as np
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import argparse
import torch

# NOTE!  For loading large datasets, like common-voice, seq2seq's hf mozilla commonvoice loader (at least for v11)
# will try to preprocess the ENTIRE SET, even if you set your splits to small %'s.
# I modified mine to:
     #if training_args.do_train:
     #    raw_datasets["train"] = load_dataset(
     #        data_args.dataset_name,
     #        data_args.dataset_config_name,
     #        #split=data_args.train_split_name,
     #  # THIS LINE HERE AND IN THE .do_eval right below this one
     #        split=f'{data_args.train_split_name}[:1%]',  # Load only the first 1%
     #        cache_dir=model_args.cache_dir,
     #        token=model_args.token,
     #        #verification_mode='all_checks',
     #    )
# AND LOWER, BEFORE prepare_dataset(), slice the dataset (or it'll still preproc everything):
# These 4 lines:
    # if data_args.max_train_samples is not None:
    #     raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
    # if data_args.max_eval_samples is not None:
    #     raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

    # def prepare_dataset(batch):

@dataclass
class AdditionalArguments:
    output_probabilities: bool = field(default=False, metadata={"help": "Output word probabilities"})

def main():
    global predwavs

    parser = argparse.ArgumentParser(description="Speech-to-Text Prediction")
    parser.add_argument("--whdir", type=str, help="Directory of the Whisper model")
    parser.add_argument("--predwav", type=str, help="Path to the audio file for prediction")
    parser.add_argument("-p", "--output_probabilities", action="store_true", help="Output word probabilities")
    parser.add_argument("-cp", "--cnt_probs", type=int, default=10, help="Number of top candidate tokens to output with their probabilities")

    # Parse arguments
    args = parser.parse_args()

    # Rest of your code
    # load model and processor
    # processor = WhisperProcessor.from_pretrained("openai/whisper-large")
    # model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
    # Load model and processor
    whdir = args.whdir if args.whdir is not None else whdir_def
    if args.predwav:
        predwavs = [args.predwav]

    print(f"Predicting on wave(s):\n{predwavs}")
    processor = WhisperProcessor.from_pretrained(whdir)
    model = WhisperForConditionalGeneration.from_pretrained(whdir)
    model.config.forced_decoder_ids = None

    # load dummy dataset and read audio files
    #if use_test_dataset_hf:
    #    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    #elif use_test_dataset_moz:
    #    ds = load_dataset(
    #        "mozilla-foundation/common_voice_11_0",
    #        "en",
    #        #split=data_args.train_split_name,
    #        split=f'train[:15%]',  # Load only the first %
    #        #cache_dir=model_args.cache_dir,
    #        token=True
    #        #verification_mode='all_checks',
    #    )
    #    example = ds[0]["audio"]
    #    sample = example['array']
    #    sr = example['sampling_rate']
    #else:
    for predwav in predwavs:
        aa,sr=librosa.load(predwav, sr=16000)
        sample=aa.astype(np.float64)

        input_features = processor(sample, sampling_rate=sr, return_tensors="pt").input_features 
        # generate token ids

        print("Generating...")
        # Raises error. don't use:
        # pids = model.generate(input_features, return_timestamps=True, return_token_timestamps=True, language='en')
        #    ".../transformers/tokenization_utils.py", line 976, in convert_ids_to_tokens
        #    index = int(index)
        #            ^^^^^^^^^^
        #    ValueError: invalid literal for int() with base 10: 's'

        # Allows timestamps but reduces transcription accuracy:
        #pids = model.generate(input_features, return_timestamps=True, language='en')

        # Highest accuracy is without timestamps:
        pids = model.generate(input_features, language='en')
        print("/Generating.")
        # decode token ids to text
        # print("batch_decode()")
        # transcription = processor.batch_decode(pids, skip_special_tokens=False)
        # print("/batch_decode")
        # print(f"Transcription: {transcription}")
        print("Timestamp info:")
        import ipdb; ipdb.set_trace(context=16); pass
        for pidi, pid in enumerate(pids):
            # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
            # decode text plus segment offsets (note: the kwarg is output_offsets)
            pdict = processor.tokenizer.decode(pid, output_offsets=True)
            print(f"Predicted id [{pidi}] text: {pdict['text']}")
            print(f"Predicted id [{pidi}] offsets: {pdict['offsets']}")
        import sys; sys.exit()
        import ipdb; ipdb.set_trace(context=16); pass

        transcription = processor.batch_decode(pids, skip_special_tokens=True)
        print(f"Transcription: {transcription}")

        if args.output_probabilities:
            # Use generate to handle decoder inputs automatically
            generated_outputs = model.generate(input_features, output_scores=True, return_dict_in_generate=True)
            scores = generated_outputs.scores  # List of tensors of scores for each step

            for stepi, step_scores in enumerate(scores):
                probabilities = F.softmax(step_scores, dim=-1)
                top_probs, top_indices = torch.topk(probabilities, args.cnt_probs, dim=-1)
                print(f"[{stepi}] Step")
                for i in range(args.cnt_probs):
                    token_id = top_indices[0][i].item()
                    word = processor.tokenizer.decode([token_id])
                    prob = top_probs[0][i].item()
                    print(f"  Token {i + 1}: {word} - {prob:.4f}")

if __name__ == '__main__':
    main()
ArthurZucker commented 5 months ago

cc @kamilakesbi as well! @jaggzh we are going to need an audio file we can work with together, and if you could reduce the reproducer to a minimal amount of custom code, that would be great!

sanchit-gandhi commented 5 months ago

Hey @jaggzh - thanks for reporting. This is actually the intended behaviour with Whisper. To understand why, recall that Whisper predicts the distribution over the next token $y_{i}$ conditioned on all previous tokens $\boldsymbol{y}_{0:i-1}$:

$$ y_{i} \sim P\left(y_{i} \mid \boldsymbol{y}_{0:i-1}\right) $$

When we decode without timestamps, we generate sequences with the following format:

<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> The cat sat on the mat.<|endoftext|>

Note the task token at position 4: the <|notimestamps|> token indicates to the model that it should not predict timestamps.

To decode with timestamps, we ensure that the <|notimestamps|> is not generated at position 4, which triggers the model to predict with timestamp tokens:

<|startoftranscript|> <|en|> <|transcribe|> <|0.00|> The cat sat on the mat.<|4.22|><|endoftext|>

=> we can see here that the sequence of token ids changes in two ways:

  1. The task tokens are changed (we drop the <|notimestamps|> token from position 4)
  2. We predict timestamp tokens as part of the generated sequence (in this example, <|0.00|> and <|4.22|>). The key here is understanding that these timestamp tokens are predicted in the same way as the text tokens: auto-regressively based on the conditional probability distribution over previous tokens.

Since the sequence of token ids $\boldsymbol{y}_{0:i-1}$ changes, the predictions for token $y_{i}$ also change (by nature of the conditional probability distribution that we predict). Therefore, it's possible that the generations with timestamps differ from those without timestamps.

Generally, what we observe is that enabling timestamps gives less accurate transcriptions for short-form audio, and more accurate transcriptions for long-form audio (whether you're using the chunked or the sequential long-form decoding algorithm).
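
One way to see this in practice is to decode both runs without skipping special tokens and compare the prompts and any timestamp tokens (a sketch, reusing the model, processor, and input_features from the reproduction above):

    # Sketch: compare the raw token sequences with and without timestamps.
    ids_plain = model.generate(input_features, language="en")
    ids_ts = model.generate(input_features, language="en", return_timestamps=True)

    # e.g. <|startoftranscript|><|en|><|transcribe|><|notimestamps|> ... <|endoftext|>
    print(processor.batch_decode(ids_plain, skip_special_tokens=False)[0])

    # e.g. <|startoftranscript|><|en|><|transcribe|><|0.00|> ... <|4.22|><|endoftext|>
    print(processor.tokenizer.decode(ids_ts[0], decode_with_timestamps=True))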

sanchit-gandhi commented 5 months ago

Closing the issue since it is in fact the intended behaviour from Whisper, but happy to answer any follow-up questions you have! Feel free to post on this comment thread 🤗

jaggzh commented 5 months ago
> 2. We predict timestamp tokens as part of the generated sequence (in this example, `<|0.00|>` and `<|4.22|>`). The key here is understanding that these timestamp tokens are predicted in the same way as the text tokens: auto-regressively based on the conditional probability distribution over previous tokens.

Thank you so much for the extremely helpful and detailed explanation! Are the timestamp tokens then generated during training, and are they actually {.02f} (two-decimal) precision? I'm dealing with short-form audio, from maybe 0.3 s to 6 s max, and very few samples reach above 1 s -- it's for someone with speech issues. If I were to give up timestamp precision, say .1f (one decimal), it might give the model less variation in the timestamp tokens and an easier time learning them accurately [I'm thinking]. My main goal is not timestamp accuracy -- the timestamps can be rough -- but rather not damaging the transcription accuracy [much] in the process.
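
(If rough timestamps are acceptable, one low-effort alternative to retraining with a coarser timestamp vocabulary is to leave generation unchanged and coarsen the returned offsets afterwards -- a sketch, using the pdict from the code above:)

    # Hedged sketch: round the decoded segment offsets to 0.1 s in
    # post-processing, leaving the model's timestamp tokens untouched.
    rounded = [
        {"text": seg["text"],
         "timestamp": (round(seg["timestamp"][0], 1), round(seg["timestamp"][1], 1))}
        for seg in pdict["offsets"]
    ]
    print(rounded)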

Nevertheless, since it's short-form, disjoint speech, I began working on a project that automatically breaks up audio using auto-calibrated silence detection. It's a module that operates as a generator function, returning each clip and its time offset, so I can use it in different projects (including my data prep or prediction code). With such short utterances, I can then take the timestamp of each clip, and that will be sufficient for my needs.

(It's not on topic, but in case anyone's interested -- not that they'll see this closed issue -- you can find it here: https://gist.github.com/jaggzh/e9a5b31afc218b8d44fd5ddb976c8c96. If run directly, it accepts an audio file so you can test your settings, but I didn't incorporate arg parsing, so you have to modify the code to change them.)

It evaluates a provided audio file (files only right now -- it can't yet be used on a live audio stream). It examines the requested seconds of audio (a chunk) and, within that, small examination windows, taking each window's max amplitude. It treats the lowest of those maxima as the noise floor. It then takes the max amplitude it heard (discarding some of the top, per maxamp_discard_frac) and uses a fraction between the floor and that max as the acceptable signal (voice) level.

[Screenshot: SS_20240525_023824]

The purpose was to adjust automatically, instead of using the fixed dB thresholds of many of the solutions I found.
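
A rough, hypothetical sketch of that auto-calibration idea (not the gist's actual code; the window size, threshold fraction, and the omitted maxamp_discard_frac step are all assumptions):

    # Hypothetical sketch of auto-calibrated silence splitting: use the quietest
    # analysis window's peak as the noise floor and set the voice threshold a
    # fraction of the way from the floor to the loudest peak. The gist also
    # discards a fraction of the loudest windows (maxamp_discard_frac), omitted
    # here for brevity.
    import numpy as np
    import librosa

    def split_on_silence(path, sr=16000, win_s=0.05, thresh_frac=0.3, min_clip_s=0.2):
        audio, sr = librosa.load(path, sr=sr)
        win = int(win_s * sr)
        n_win = len(audio) // win
        maxamps = np.array([np.abs(audio[i*win:(i+1)*win]).max() for i in range(n_win)])
        floor = maxamps.min()                                   # quietest window ~= noise floor
        thresh = floor + thresh_frac * (maxamps.max() - floor)  # acceptable voice level
        voiced = maxamps > thresh

        start = None
        for i, v in enumerate(voiced):
            if v and start is None:
                start = i
            elif not v and start is not None:
                if (i - start) * win_s >= min_clip_s:
                    yield audio[start*win:i*win], start * win_s  # (clip, offset in seconds)
                start = None
        if start is not None:
            yield audio[start*win:], start * win_s

    # usage:
    # for clip, offset in split_on_silence("patient.flac"):
    #     ...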

If plotting, it ends up using my non-breaking key module (kbnb) -- that import can just be left out if you're not using it. Otherwise, it's included in the gist, along with bansi.py for some perdy colors, also used in the plotting.

In any case, it's also a good example of matplotlib running and updating its window in the bg, non-blocking. :)

jaggzh commented 3 months ago

I have a new idea, since timestamps are useful, and accuracy is useful. If we remove [some] timestamps from the recurrent context, while maintaining them in the output to the user, we might be able to maintain accuracy. I'm not very well aware of the caching mechanisms involved, but the idea would be something like this:

Input audio: "Hello world"

Starting context: <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|>

[Model predicts "Hello" with high accuracy (and is therefore at <|0.00|>)] -> <|startoftranscript|> <|en|> <|transcribe|> Hello

[We insert <|0.00|>] Recurrent context: <|startoftranscript|> <|en|> <|transcribe|> <|0.00|>Hello

[Model predicts "<|0.04|>"] -> <|startoftranscript|> <|en|> <|transcribe|> Hello<|0.04|>

[If it was a timestamp token, we keep it for the 'context-to-user', but strip it from the recurrent context:] Recurrent context: <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Hello

[Model predicts " world" (high accuracy)] -> <|startoftranscript|> <|en|> <|transcribe|> Hello world

Recurrent context: <|startoftranscript|> <|en|> <|transcribe|> <|0.00|>Hello
OR: Recurrent context: <|startoftranscript|> <|en|> <|transcribe|> Hello<|0.04|> world

(I'm not sure how happy the model will be without the initial <|0.00|>.)

Two possible variations:

  1. Alternate stripped and non-stripped timestamps in the recurrent context (the example run above ^^^).
  2. Use "dynamic sparse timestamps", which will (hopefully) interfere less with the model. An example of #2 might be:

     Typical timestamp context: "<|0.00|>The<|0.04|> rain<|0.10|> in Spain<|0.16|> falls mainly<|0.26|> on the<|0.34|> plain."

     Sparse (dynamic/stripped) recurrent timestamp context: "<|0.00|>The rain in Spain<|0.16|> falls mainly" -> " on" -> " the"
     Sparse (dynamic/stripped) recurrent timestamp context: "<|0.00|>The rain in Spain falls mainly<|0.26|> on the" -> " plain"
     Sparse (dynamic/stripped) recurrent timestamp context: "<|0.00|>The rain in Spain falls mainly<|0.26|> on the plain."

By using dynamic stripping, we can choose, on each pass, which timestamp tokens we keep, the idea being that the attention head(s) can match up enough of the audio features to transcription tokens to maintain next-token accuracy. When we expect a timestamp token, we can include a prior timestamp closer to the last token.

(We could also attempt to force a timestamp or text-token prediction, as needed, with prefix_allowed_tokens_fn, for example. But this could be optional or an experimental part of the algorithm -- or come with an adjustable token spacing.)
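
As a rough illustration of just the stripping step in this idea (the per-step context manipulation itself isn't something generate() exposes, so it would need a custom decoding loop), here's a hedged helper that assumes the standard Whisper vocabulary layout, where every timestamp token id is greater than the id of <|notimestamps|>:

    # Hedged sketch: drop timestamp tokens from a generated id sequence, keeping
    # them only for the user-facing output. Assumes all timestamp token ids sit
    # above the <|notimestamps|> id, as in the standard Whisper vocabulary.
    def strip_timestamp_tokens(token_ids, no_timestamps_id):
        return [int(t) for t in token_ids if int(t) <= no_timestamps_id]

    # usage (processor/pids as in the reproduction code above):
    # no_ts_id = processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
    # recurrent_ctx = strip_timestamp_tokens(pids[0], no_ts_id)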

@sanchit-gandhi @ArthurZucker @younesbelkada

sanchit-gandhi commented 3 months ago

That's correct @jaggzh - the model is trained to predict timestamps at 0.02 s (20 ms) precision. See page 3 of the Whisper paper for details.

Changing the precision of the timestamps is unlikely to get you any improvements in transcription accuracy. In fact, you risk potentially lower timestamp accuracy as you generate, since you deviate away from the most probable predictions.

Regarding modifying the decoding algorithm: If you want to be able to predict timestamps at index i, then you need to have predicted timestamps for indices 0:i-1. If you pass in previous ids that don't have timestamps, I'm pretty sure you'll mess up the predictions for index i.

One option to try and get the best of both worlds is what they do in Whisper-X - use Whisper for the transcriptions, but wav2vec2 for the timestamps.
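
For reference, the WhisperX README (m-bain/whisperX) shows roughly the following two-stage usage; the exact API may differ between versions, so treat this as a sketch rather than an authoritative example:

    # Rough sketch of the WhisperX two-stage approach (transcribe with Whisper,
    # align with wav2vec2); adapted from the project README, API may have changed.
    import whisperx

    device = "cuda"
    audio = whisperx.load_audio("patient.flac")   # placeholder path

    model = whisperx.load_model("large-v2", device)
    result = model.transcribe(audio)

    align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    aligned = whisperx.align(result["segments"], align_model, metadata, audio, device)
    print(aligned["segments"])  # word/segment timings from the wav2vec2 alignment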

jaggzh commented 3 months ago

I'm so sorry -- I'm referring to maintaining the transcription accuracy, not changing the timestamp accuracy. The idea is to use a pass with the <|notimestamps|> token, likely with a dynamic context (stripping prior timestamps so as not to confuse the model with timestamps in the context), and then passes WITH the request for timestamps, with timestamps included (or possibly only partial timestamps included -- this would need to be tested, and may vary depending on whether the user is aiming for high-resolution, word-level timestamps or not).
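
One way to approximate this without touching the decoding internals is a simple two-pass scheme: keep the text from the no-timestamp pass and take only rough segment boundaries from the timestamped pass. A hedged sketch (with the caveat from the discussion above that the two passes may segment the audio differently):

    # Hedged two-pass sketch: text from the (more accurate) no-timestamp pass,
    # rough timing from the timestamped pass. The two passes can disagree, so
    # the boundaries are only approximate.
    ids_plain = model.generate(input_features, language="en")
    ids_ts = model.generate(input_features, language="en", return_timestamps=True)

    text = processor.batch_decode(ids_plain, skip_special_tokens=True)[0]
    offsets = processor.tokenizer.decode(ids_ts[0], output_offsets=True)["offsets"]

    rough_span = (offsets[0]["timestamp"][0], offsets[-1]["timestamp"][1]) if offsets else None
    print(text, rough_span)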