Code and models for the INTERSPEECH 2024 paper *Simul-Whisper: Attention-Guided Streaming Whisper with Truncation Detection*.
We use the model checkpoints provided by openai-whisper in its open-source code; the download links are as follows:
| Model | Link |
|---|---|
| base | Download |
| small | Download |
| medium | Download |
| large-v2 | Download |
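If the links above are unavailable, one alternative (a sketch, assuming the `openai-whisper` package is installed; the `download_root` path is illustrative) is to let openai-whisper fetch the checkpoint itself:

```python
import whisper

# Downloads the checkpoint to ./checkpoints/base.pt if not already cached
# (download_root defaults to ~/.cache/whisper when omitted).
whisper.load_model("base", download_root="./checkpoints")
```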
We used Python 3.9.16 and PyTorch 2.0.1, but the code may be compatible with other nearby Python and PyTorch versions. Other dependencies are listed in `requirements.txt`.
Sample code is provided in `transcribe.py`. It consists of three main steps:

1. Model initialization:
```python
cfg = AlignAttConfig(
    model_path=model_path,
    segment_length=segment_length,  # chunk length, in seconds
    frame_threshold=frame_threshold,  # threshold for attention-guided decoding, in frames
    language=language,
    buffer_len=buffer_len,  # length of the context buffer, in seconds
    min_seg_len=min_seg_len,  # transcribe only when the context buffer is longer than this threshold; useful when segment_length is small
    if_ckpt_path=if_ckpt_path,
)
model = PaddedAlignAttWhisper(cfg)
```
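For illustration, a hypothetical instantiation with placeholder values (these are not tuned settings from the paper, and the checkpoint paths are assumptions):

```python
cfg = AlignAttConfig(
    model_path="./checkpoints/base.pt",  # Whisper checkpoint from the table above
    segment_length=1.0,                  # 1-second chunks
    frame_threshold=12,                  # illustrative value
    language="en",
    buffer_len=20.0,                     # illustrative value
    min_seg_len=0.0,                     # transcribe regardless of buffer length
    if_ckpt_path="path/to/if_ckpt.pt",   # hypothetical truncation-detection checkpoint path
)
model = PaddedAlignAttWhisper(cfg)
```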
2. Audio loading and segmentation:

```python
segmented_audio = SegmentWrapper(audio_path=audio_path, segment_length=segment_length)
```
To provide CNN features at audio boundaries that are similar between streaming and non-streaming inference, we retain a 240-sample buffer. The process of reading audio is illustrated in the figure below.
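For intuition, a minimal sketch of this overlapped reading scheme (the name `read_segments` and the exact buffering logic are illustrative, not the actual `SegmentWrapper` implementation):

```python
import numpy as np

BUFFER = 240  # samples retained from the end of the previous segment

def read_segments(audio: np.ndarray, seg_samples: int):
    """Yield (segment, is_last) pairs; each segment starts BUFFER samples
    early so the CNN sees boundary context similar to offline inference."""
    start = 0
    while start < len(audio):
        end = min(start + seg_samples, len(audio))
        left = max(start - BUFFER, 0)
        yield audio[left:end], end == len(audio)
        start = end
```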
3. Segment-wise inference:

```python
hyp_list = []
for seg_id, (seg, is_last) in enumerate(segmented_audio):
    new_toks = model.infer(seg, is_last)
    hyp_list.append(new_toks)
hyp = torch.cat(hyp_list, dim=0)
hyp = hyp[hyp < DEC_PAD]
hyp = model.tokenizer.decode(hyp)
print(hyp)
model.refresh_segment(complete=True)  # refresh the buffer when an utterance is fully decoded
```
When transcribing two different utterances, call `model.refresh_segment` after the first utterance ends to refresh the buffer.
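For example, a sketch of transcribing two utterances back to back (the file names are placeholders):

```python
for path in ["utt1.wav", "utt2.wav"]:  # placeholder paths
    segments = SegmentWrapper(audio_path=path, segment_length=segment_length)
    toks = [model.infer(seg, is_last) for seg, is_last in segments]
    hyp = torch.cat(toks, dim=0)
    print(model.tokenizer.decode(hyp[hyp < DEC_PAD]))
    model.refresh_segment(complete=True)  # reset the buffer before the next utterance
```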
This code is based on openai-whisper and whisper_streaming.