
Simul-Whisper

Code and models for the INTERSPEECH 2024 paper Simul-Whisper: Attention-Guided Streaming Whisper with Truncation Detection

Setup

Download openai-whisper model checkpoints

We use the model checkpoints provided by openai-whisper in their open-source release. The download links are as follows:

| Model | Link |
|-------|------|
| base | Download |
| small | Download |
| medium | Download |
| large-v2 | Download |

Install dependencies

We used Python 3.9.16 and PyTorch 2.0.1, but the code should be compatible with other nearby Python and PyTorch versions. The remaining dependencies are listed in requirements.txt and can be installed with `pip install -r requirements.txt`.

Python Usage

Sample code is provided in transcribe.py. It includes three main steps:

1. Configurations

    cfg = AlignAttConfig(
        model_path=model_path,
        segment_length=segment_length, # chunk length, in seconds
        frame_threshold=frame_threshold, # threshold for attention-guided decoding, in frames
        language=language,
        buffer_len=buffer_len, # length of the context buffer, in seconds
        min_seg_len=min_seg_len, # transcribe only when the context buffer is longer than this threshold; useful when segment_length is small
        if_ckpt_path=if_ckpt_path,
    )
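
For concreteness, a filled-in configuration might look like the sketch below. All paths and values here are illustrative assumptions, not recommended settings:

    cfg = AlignAttConfig(
        model_path="ckpts/base.pt",       # hypothetical path to a downloaded Whisper checkpoint
        segment_length=1.0,               # decode in 1-second chunks
        frame_threshold=12,               # illustrative value, in frames
        language="en",
        buffer_len=20.0,                  # keep up to 20 s of context (illustrative)
        min_seg_len=0.0,                  # transcribe regardless of buffer length
        if_ckpt_path="ckpts/cif_base.pt", # hypothetical path to the truncation-detection checkpoint
    )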

2. Model Initialization

    model = PaddedAlignAttWhisper(cfg)
    segmented_audio = SegmentWrapper(audio_path=audio_path, segment_length=segment_length)

To provide similar CNN features at audio boundaries in streaming and non-streaming inference, we retain a 240-sample buffer. The audio-reading process is illustrated in the following figure:

[Figure: example of reading audio segments with the 240-sample boundary buffer]
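
The repository's SegmentWrapper implements this buffering; purely to illustrate the idea, here is a minimal sketch of chunked reading with a 240-sample left context (iter_segments, SAMPLE_RATE, and BOUNDARY_BUFFER are our own names, not the repository's API):

    import numpy as np

    SAMPLE_RATE = 16000
    BOUNDARY_BUFFER = 240  # samples of the previous chunk kept at each boundary

    def iter_segments(audio: np.ndarray, segment_length: float):
        """Yield (segment, is_last) pairs, prepending the last 240 samples of
        the previous chunk so the CNN front-end sees boundary context similar
        to non-streaming inference."""
        hop = int(segment_length * SAMPLE_RATE)
        for start in range(0, len(audio), hop):
            left = max(0, start - BOUNDARY_BUFFER)   # reach back into the previous chunk
            end = min(len(audio), start + hop)
            yield audio[left:end], end == len(audio)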

3. Transcribe the segments

    hyp_list = []
    for seg_id, (seg, is_last) in enumerate(segmented_audio):
        new_toks = model.infer(seg, is_last)  # decode the new chunk
        hyp_list.append(new_toks)
        hyp = torch.cat(hyp_list, dim=0)      # concatenate all tokens decoded so far
        hyp = hyp[hyp < DEC_PAD]              # drop padding tokens
        hyp = model.tokenizer.decode(hyp)
        print(hyp)

    model.refresh_segment(complete=True) # refresh the buffer once the utterance is fully decoded
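
Note that the hypothesis is concatenated, decoded, and printed inside the loop, so a partial transcript is emitted after every chunk as it arrives.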

When transcribing two different utterances, call `model.refresh_segment(complete=True)` to refresh the buffer after the first utterance ends.
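
For example, to transcribe two utterances back to back (a sketch reusing the names from the snippets above; the file names are hypothetical):

    for audio_path in ["utt1.wav", "utt2.wav"]:  # hypothetical file names
        segmented_audio = SegmentWrapper(audio_path=audio_path, segment_length=segment_length)
        hyp_list = []
        for seg_id, (seg, is_last) in enumerate(segmented_audio):
            hyp_list.append(model.infer(seg, is_last))
        hyp = torch.cat(hyp_list, dim=0)
        print(model.tokenizer.decode(hyp[hyp < DEC_PAD]))
        model.refresh_segment(complete=True)  # clear the buffer before the next utterance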

Reference Repositories

This code is based on openai-whisper and whisper_streaming.