OpenNMT / CTranslate2

Fast inference engine for Transformer models
https://opennmt.net/CTranslate2
MIT License

Whisper batch generation is not faster than loops #1648

Open: evanarlian opened this issue 8 months ago

evanarlian commented 8 months ago

With the CTranslate2 Whisper model, batched generate is not faster than looping over the samples one by one. I tried the same thing with the Translator model, and there batching is far superior (a lot faster). I used Whisper small converted to int8 with the ct2 conversion tool. GPU memory usage is also higher when batching, so I assume CTranslate2 is doing "proper" batching and not just wrapping a loop. Here is my simple Whisper code:

import time

import numpy as np
from ctranslate2 import StorageView
from ctranslate2.models import Whisper
from transformers import WhisperProcessor

def make_prompts(tokenizer, n: int) -> list[list[int]]:
    prompt = tokenizer.convert_tokens_to_ids(
        [
            "<|startoftranscript|>",
            "<|en|>",
            "<|transcribe|>",
        ]
    )
    return [prompt] * n

def loop(
    whisper: Whisper,
    features: list[StorageView],
    prompts: list[list[int]],
):
    for feat, prompt in zip(features, prompts):
        _ = whisper.generate(
            feat,
            [prompt],
            return_scores=True,
            return_no_speech_prob=True,
        )

def batch(
    whisper: Whisper,
    features: StorageView,
    prompts: list[list[int]],
):
    _ = whisper.generate(
        features,
        prompts,
        return_scores=True,
        return_no_speech_prob=True,
    )

def main():
    N_SAMPLES = 8
    N_SEC = 5
    SR = 16000

    # load model and processor
    whisper = Whisper("models/whisper-small", device="cuda")
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    tokenizer = processor.tokenizer

    # generate required data
    chunks = np.random.random((N_SAMPLES, N_SEC * SR)).astype(np.float32)
    inputs = processor(chunks, return_tensors="np", sampling_rate=SR)
    mels = inputs["input_features"]
    features_loop = [StorageView.from_array(m[None, :]) for m in mels]
    features_batch = StorageView.from_array(mels)
    prompts = make_prompts(tokenizer, N_SAMPLES)

    # warm up
    print("warming up... ", end="", flush=True)
    for _ in range(7):
        loop(whisper, features_loop, prompts)
        batch(whisper, features_batch, prompts)
    print("done")

    N = 20
    print(f"benchmarking each method for {N} iterations")

    # loop time
    t0 = time.perf_counter()
    for _ in range(N):
        loop(whisper, features_loop, prompts)
    elapsed = time.perf_counter() - t0
    print(f"loop time: {elapsed:0.3f} secs")

    # batch time
    t0 = time.perf_counter()
    for _ in range(N):
        batch(whisper, features_batch, prompts)
    elapsed = time.perf_counter() - t0
    print(f"batch time: {elapsed:0.3f} secs")

if __name__ == "__main__":
    main()

When I ran the code on Colab (T4 GPU), it printed:

benchmarking each method for 20 iterations
loop time: 25.311 secs
batch time: 30.086 secs
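Normalizing the reported totals per sample makes the gap concrete. This is plain arithmetic on the numbers printed above (20 iterations of 8 samples each). Nothing is re-measured:

```python
N_ITERS = 20    # benchmark iterations per method
N_SAMPLES = 8   # samples processed per iteration

loop_secs = 25.311   # reported total for the per-sample loop
batch_secs = 30.086  # reported total for the single batched call

# Average wall-clock seconds spent per sample in each mode
loop_per_sample = loop_secs / (N_ITERS * N_SAMPLES)
batch_per_sample = batch_secs / (N_ITERS * N_SAMPLES)

print(f"loop:  {loop_per_sample * 1000:.1f} ms/sample")
print(f"batch: {batch_per_sample * 1000:.1f} ms/sample")
print(f"batch / loop: {batch_secs / loop_secs:.2f}x")
```

The loop works out to roughly 158 ms per sample and the batched call to roughly 188 ms per sample. In other words, the batched run took about 1.19x as long.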

Is there anything I could do to increase the speed of Whisper batch generation?

BBC-Esq commented 8 months ago

Absolutely! So glad you asked! lol. CTranslate2 actually does support true batching, but at the C++ level. Here's my repository, which uses it via the amazing WhisperS2T, along with a direct link to WhisperS2T itself. My understanding is that there's been a fair amount of discussion on the faster-whisper repository about "batch" processing, but that it isn't feasible right now given the state of that repository. On the other hand, faster-whisper has several features that WhisperS2T does not. Keep in mind that my repository is a "little" outdated because I haven't updated it to the most recent WhisperS2T, so check upstream for any API changes.

However, if you use my repo for sample scripts and keep the same versions, you should be fine. I have a lot of experience with WhisperS2T now, so feel free to hit me up.

https://github.com/BBC-Esq/WhisperS2T-transcriber

...and the amazing...

https://github.com/shashikg/WhisperS2T

At ~150 stars it flies under the radar... but it still beats Hugging Face's "insanely" (hate that name) Whisper implementation, which has thousands of stars. It just goes to show that the number of stars a stereotypical Hugging Face repo gets is NOT AT ALL related to the quality of the product, and is boosted more by marketing and networking buddy referrals... Give credit where credit is due, is what I say. Try WhisperS2T. I'm interested in your feedback!

BBC-Esq commented 8 months ago

BTW, I just haven't had time to update my WhisperS2T batch repo with this bad boy, so stay tuned. ;-)

[screenshot]

It allows you to specify the task, choose any CTranslate2 quantization you want, process all sub-directories recursively, exclude certain file extensions from processing, change the beam size and batch size (courtesy of WhisperS2T), and so on.

BBC-Esq commented 8 months ago

Last post, I promise... but here's my analysis of WhisperS2T. I believe my repo uses a traditional "loop" to process files with WhisperS2T... but you can also send a whole batch directly to CTranslate2, which is how WhisperS2T is inherently meant to run. HOWEVER, I opted for the "loop" method because if you send all the audio files at once and ONE fails, they all fail and you get ZERO transcriptions. What I found is that if I process, say, 500 audio files, ONE might have corrupted data, and that single file makes the entire process error out...
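The all-or-nothing behavior described above can be illustrated with a toy sketch. `transcribe_one` and `transcribe_all` are hypothetical stand-ins, not WhisperS2T or CTranslate2 APIs. The batched call raises on the first corrupted file and returns nothing, while the per-file loop records the failure and keeps going:

```python
from typing import Callable, Dict, List, Tuple


def transcribe_one(path: str) -> str:
    """Hypothetical per-file transcriber that fails on 'corrupted' input."""
    if path.endswith(".bad"):
        raise ValueError(f"corrupted audio: {path}")
    return f"text for {path}"


def transcribe_all(paths: List[str]) -> List[str]:
    """Hypothetical batched call: one bad file aborts the whole request."""
    return [transcribe_one(p) for p in paths]


def transcribe_in_loop(
    paths: List[str], fn: Callable[[str], str]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Process files one at a time so a single failure doesn't discard the rest."""
    results: Dict[str, str] = {}
    errors: Dict[str, str] = {}
    for p in paths:
        try:
            results[p] = fn(p)
        except Exception as exc:  # record the bad file and move on
            errors[p] = str(exc)
    return results, errors


files = ["a.wav", "b.bad", "c.wav"]

try:
    transcribe_all(files)
except ValueError as exc:
    print("batch failed, zero transcriptions:", exc)

ok, failed = transcribe_in_loop(files, transcribe_one)
print(sorted(ok))      # ['a.wav', 'c.wav']
print(sorted(failed))  # ['b.bad']
```

The price of the loop is that files are no longer processed together in one batch.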

This is supposed to be fixed, however, per this discussion:

https://github.com/shashikg/WhisperS2T/issues/50

Anyway, here is my analysis of the library (not the most current version, however):

MY PERSONAL SUMMARY
```
transcribe_with_vad (backends/__init__.py)
The transcribe_with_vad method in WhisperModel uses voice activity detection to transcribe audio files. It first corrects the batch parameters for language codes, tasks, and initial prompts with fix_batch_param. It then processes the audio files in batches through WhisperDataLoader, converts the signals to mel spectrograms with LogMelSpectogram (self.preprocessor), and segments the audio based on voice activity. The transcription itself is handled by generate_segment_batched, an abstract method implemented by subclasses. Progress is tracked with tqdm.

fix_batch_param (backends/__init__.py)
This function is used in the WhisperModel class to prepare batch parameters such as language codes, tasks, and initial prompts, ensuring they match the number of audio files being processed.

WhisperDataLoader (data.py)
The WhisperDataLoader class prepares audio data for transcription by segmenting and batching it. It relies on external configuration (SAMPLE_RATE, N_SAMPLES), utility functions (pad_or_trim, audio_batch_generator), BasicSegmenter and stitch_speech_segments from the same script, and torch and numpy. Based on input flags, it uses either voice activity detection (speech_segmenter) or basic segmentation (basic_segmenter). The class segments the audio files, optionally merges speech segments while respecting the maximum speech length, and creates batches containing the processed audio signals, prompts, and metadata. Its data_collate_fn method assembles these into a format suitable for model input, padding or trimming the audio to a uniform length and organizing the prompts. It supports dynamic time-axis adjustment for batch processing and yields batches ready for transcription, tracking progress with tqdm.

LogMelSpectrogram (audio.py)
The LogMelSpectogram class, a subclass of nn.Module, converts audio signals into log-mel spectrogram features.
It is initialized with the mel-spectrogram parameters (n_mels, n_fft, hop_length, padding) and loads mel filter banks from a predefined file, registering them as a buffer. The class also holds a TorchSTFT instance for computing short-time Fourier transforms (STFT). It provides a get_seq_len method that adjusts sequence lengths based on hop_length, and a forward method that applies padding (if required), computes the STFT, converts the power spectrogram to the mel scale using the loaded mel filters, applies logarithmic scaling, and clips and scales the log-mel spectrogram. This prepares audio data for deep learning models in speech processing tasks. It uses torch, numpy, F.pad from PyTorch's functional API, and custom configuration (N_MELS, N_FFT, HOP_LENGTH, BASE_PATH).

WhisperModelCT2 (model.py)
This subclass provides a concrete implementation of the abstract method generate_segment_batched. Here's a brief overview of how WhisperModelCT2 handles transcription once generate_segment_batched is reached:

1. Initialization and configuration: the WhisperModelCT2 constructor initializes the model with the ASR (automatic speech recognition) options, loads the model and tokenizer, and sets up the parameters for generating transcriptions.
2. Model and tokenizer loading: it loads a translation or transcription model with ctranslate2 from a path or model name. The tokenizer is loaded from a specified file.
3. Transcription (generate_segment_batched):
   - Converts the features (audio processed into a suitable format such as log-mel spectrograms) into the format expected by the ctranslate2 model.
   - Calls the ctranslate2 model's generate method with these features and the specified generation options. This step performs the actual ASR by generating text from the input audio features.
   - Decodes the model output into human-readable text using the loaded tokenizer.
   - Optionally computes additional metrics, such as the average log probability of the sequences and the no-speech probability, if specified in the generation options.
   - If word timestamps are required, aligns the generated text with the audio features to produce word-level timestamps. This calls align_words, which uses the ctranslate2 model's align function, and then assign_word_timings to assign timings to individual words.
4. Word timings and alignment: if word timestamps are enabled, WhisperModelCT2 uses aligner_model (another ctranslate2 model instance) to align the words in the transcription with their positions in the audio, producing timing information for each word.
5. Returning the transcription: the method returns a list of dictionaries. Each dictionary contains the transcribed text and, depending on the configuration, the average log probability, the no-speech probability, and word-level timing information.

model.transcribe_with_vad

model = whisper_s2t.load_model(
    model_identifier=model_identifier,
    backend='CTranslate2',
    device=self.device,
    compute_type=self.quantization,
    asr_options={'beam_size': self.beam_size},
    cpu_threads=os.cpu_count(),
)

MODEL TRANSCRIPTION PROCESS
----------------------------
model.transcribe_with_vad
│
├─ fix_batch_param (Adjustment of Parameters)
│  └─ Applies to: lang_codes, tasks, initial_prompts
│
└─ WhisperDataLoader (Data Preparation and Loading)
   │
   ├─ Conditional Branching: use_vad flag
   │  ├─ speech_segmenter (if VAD enabled)
   │  ├─ basic_segmenter (if VAD disabled)
   │  └─ Optional: stitch_speech_segments (Merge if merge_chunks=True)
   │
   ├─ get_segmented_audio_signal (Audio Segmentation)
   │  ├─ tokenizer.sot_sequence (Start of Token Sequence)
   │  └─ tokenizer.encode (Encoding Initial Prompts)
   │
   ├─ data_collate_fn (Data Collation)
   │  ├─ pad_or_trim (Adjust Audio Signal Lengths)
   │  └─ External: Torch Operations (Stacking and Tensor Creation)
   │
   └─ Yields Batches to transcribe_with_vad (Transcription)
      │
      ├─ preprocessor (Feature Extraction)
      │  ├─ TorchSTFT (Spectral Transformation)
      │  └─ Mel Filter Application & Log Scaling
      │
      └─ generate_segment_batched (in WhisperModelCT2)
         │
         ├─ ctranslate2 Model's `generate` (ASR Generation)
         │  ├─ Decoding Output to Text
         │  ├─ Average Log Probability (Optional)
         │  └─ No-Speech Probability (Optional)
         │
         ├─ align_words (Word Alignment)
         │  ├─ aligner_model's `align` (Alignment Process)
         │  └─ assign_word_timings (Timestamp Assignment)
         │
         └─ Structured Transcription Output
            └─ Includes: Text, avg_logprob, no_speech_prob, word_timestamps
               └─ Creation of "out" variable (Result Packaging)

The difference between the original version of my transcriber and the new version that processes each file separately:

First snippet: processing files individually in a loop
• Sequential processing: each audio file is processed one at a time in a while loop, which keeps running until the file_queue is empty and enumeration_done is set.
  This approach allows real-time updates and handling of files as they become available, which is particularly useful when files are being added to the queue dynamically.

Second snippet: batch processing multiple files
• Batch processing: processes a list of audio files (audio_files_str) in a single call to transcribe_with_vad. This approach is more efficient if all audio files are available at the start, since it can take advantage of batch processing optimizations.
```

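The data-preparation steps summarized above (fix_batch_param matching per-file options to the number of files, pad_or_trim forcing a uniform length, data_collate_fn grouping items into batches) can be sketched in plain Python. This is a hypothetical illustration of the idea, not WhisperS2T's code. `broadcast_param` and `make_batches` are made-up names, the `pad_or_trim` here only mimics the described behavior, and real audio would be a NumPy array or torch tensor rather than a list. The 480,000-sample target corresponds to Whisper's 30-second window at 16 kHz.

```python
from __future__ import annotations

from typing import Iterator

SAMPLE_RATE = 16_000
N_SAMPLES = 30 * SAMPLE_RATE  # Whisper's fixed 30-second window: 480,000 samples


def broadcast_param(value: object, n: int) -> list:
    """Hypothetical: repeat one setting per file, or validate a per-file list."""
    if isinstance(value, list):
        if len(value) != n:
            raise ValueError(f"expected {n} values, got {len(value)}")
        return list(value)
    return [value] * n


def pad_or_trim(signal: list[float], length: int = N_SAMPLES) -> list[float]:
    """Zero-pad or truncate a 1-D signal to exactly `length` samples."""
    if len(signal) >= length:
        return signal[:length]
    return signal + [0.0] * (length - len(signal))


def make_batches(items: list, batch_size: int) -> Iterator[list]:
    """Yield consecutive groups of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Toy example: three "signals" of different lengths, padded/trimmed to 10 samples
signals = [[0.1] * 5, [0.2] * 12, [0.3] * 8]
lang_codes = broadcast_param("en", len(signals))  # same language for every file
padded = [pad_or_trim(s, length=10) for s in signals]
batches = list(make_batches(list(zip(padded, lang_codes)), batch_size=2))

print([len(b) for b in batches])            # [2, 1]
print([len(sig) for sig, _ in batches[0]])  # [10, 10]
```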
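Two details of the log-mel step can be checked by hand. With Whisper's hop length of 160 samples, a 30-second, 16 kHz window yields 3,000 frames, which matches the `[batch_size, 80, 3000]` features fed to `generate` in the benchmark script for whisper-small. The log/clip/scale step below follows OpenAI Whisper's reference `log_mel_spectrogram`: a floor at 1e-10, clipping to 8 below the maximum, then `(x + 4) / 4`. Whether WhisperS2T's LogMelSpectogram uses the same constants is an assumption. Plain-Python sketch:

```python
import math

SAMPLE_RATE = 16_000
HOP_LENGTH = 160              # samples between successive STFT frames
N_SAMPLES = 30 * SAMPLE_RATE  # 30-second window = 480,000 samples


def num_frames(n_samples: int, hop_length: int = HOP_LENGTH) -> int:
    """Mel frames covering n_samples (ignores STFT edge frames)."""
    return n_samples // hop_length


def normalize_log_mel(mel_energies: list) -> list:
    """Whisper-style log-mel normalization on a flat list of mel energies."""
    # log10 with a small floor so zeros don't produce -inf
    log_spec = [math.log10(max(x, 1e-10)) for x in mel_energies]
    # clip everything more than 8 log10 units below the loudest value
    floor = max(log_spec) - 8.0
    # shift and scale into a small range around zero
    return [(max(v, floor) + 4.0) / 4.0 for v in log_spec]


print(num_frames(N_SAMPLES))                 # 3000
print(normalize_log_mel([100.0, 1.0, 0.0]))  # [1.5, 1.0, -0.5]
```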
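The first snippet's pattern is a worker that keeps pulling from file_queue while files are still being discovered, and it can be sketched with the standard library. Only the names file_queue and enumeration_done come from the summary. `transcribe_file` and `worker` are hypothetical, and the 0.1 s poll timeout is an arbitrary choice. The worker stops only once enumeration_done is set and the queue is empty:

```python
import queue
import threading
from typing import List


def transcribe_file(path: str) -> str:
    """Hypothetical per-file transcription call."""
    return f"transcript of {path}"


def worker(file_queue: "queue.Queue[str]", enumeration_done: threading.Event,
           results: List[str]) -> None:
    # Keep going while more files may still arrive or some are still queued.
    while not (enumeration_done.is_set() and file_queue.empty()):
        try:
            path = file_queue.get(timeout=0.1)
        except queue.Empty:
            continue  # nothing yet; re-check whether enumeration has finished
        results.append(transcribe_file(path))
        file_queue.task_done()


file_queue: "queue.Queue[str]" = queue.Queue()
enumeration_done = threading.Event()
results: List[str] = []

t = threading.Thread(target=worker, args=(file_queue, enumeration_done, results))
t.start()
for name in ["a.wav", "b.wav", "c.wav"]:  # files "discovered" while the worker runs
    file_queue.put(name)
enumeration_done.set()  # signal that no more files will be added
t.join()
print(results)
```

Because the producer sets enumeration_done only after its last put, the worker cannot exit while files are still pending.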
evanarlian commented 7 months ago

Thank you for telling me about WhisperS2T. I'll take a look later. Currently I'm not using faster-whisper. I'm using CTranslate2 directly. I was hoping batching would speed up generation, but right now it gives no speedup over the standard loop.

BBC-Esq commented 7 months ago

WhisperS2T basically uses CTranslate2 directly.