backspacetg / simul_whisper

Code for our INTERSPEECH paper Simul-Whisper: Attention-Guided Streaming Whisper with Truncation Detection

Dimension bug #1

Closed (tjongsma closed 1 week ago)

tjongsma commented 2 weeks ago

Hi there,

Running the following code (adapted from the example):

import os
import sys
import torch

sys.path.append(os.path.dirname(__file__))
from simul_whisper.transcriber.config import AlignAttConfig
from simul_whisper.transcriber.segment_loader import SegmentWrapper
from simul_whisper.transcriber.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
base_dir = os.path.dirname(os.path.abspath(sys.argv[0]))

model_path = os.path.join(base_dir, "medium.pt")
if_ckpt_path = os.path.join(base_dir, "cif_models","medium.pt") # align with the whisper model. e.g., using small.pt for whisper small

segment_length = 1.0 # chunk length, in seconds
frame_threshold = 12 # threshold for the attention-guided decoding, in frames
buffer_len = 20 # the lengths for the context buffer, in seconds
min_seg_len = 0.0 # transcribe only when the context buffer is larger than this threshold. Useful when the segment_length is small
language = "en"

audio_path = os.path.join(base_dir, "diarization_voorbeeld_16khz.wav")

if __name__ == "__main__":

    cfg = AlignAttConfig(
        model_path=model_path, 
        segment_length=segment_length,
        frame_threshold=frame_threshold,
        language=language,
        buffer_len=buffer_len, 
        min_seg_len=min_seg_len,
        if_ckpt_path=if_ckpt_path,
    )

    model = PaddedAlignAttWhisper(cfg)
    segmented_audio = SegmentWrapper(audio_path=audio_path, segment_length=segment_length)

    hyp_list = []
    for seg_id, (seg, is_last) in enumerate(segmented_audio):
        new_toks = model.infer(seg, is_last)
        hyp_list.append(new_toks)
        hyp = torch.cat(hyp_list, dim=0)
        hyp = hyp[hyp < DEC_PAD]
        hyp = model.tokenizer.decode(hyp)
        print(hyp)

    model.refresh_segment(complete=True) # refresh the buffer when an utterance is decoded

gives me this error:

Traceback (most recent call last):
  File "c:\Users\tjong\Desktop\Audio_transcription_dev\whisper_streaming\simul_whisper\simul_whisper-main\transcribe.py", line 39, in <module>
    new_toks = model.infer(seg, is_last)
               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tjong\Desktop\Audio_transcription_dev\whisper_streaming\simul_whisper\simul_whisper-main\simul_whisper\transcriber\simul_whisper.py", line 199, in infer
    encoder_feature = self.model.encoder(mel)
                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\tjong\Desktop\Audio_transcription_dev\whisper_streaming\simul_whisper\simul_whisper-main\simul_whisper\whisper\model.py", line 166, in forward
    x = F.gelu(self.conv1(x))
               ^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tjong\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [1, 2, 80, 3000]

Any idea what could be the issue? Thanks in advance!

backspacetg commented 2 weeks ago

Thank you for your attention! It seems that the input audio has 2 channels; converting it to mono in advance should solve the problem. We will also update the code to handle this case.
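
For readers hitting the same error, here is a minimal sketch of the mono conversion suggested above. It is not taken from the simul_whisper repository: the helper names `downmix_interleaved` and `wav_to_mono` are hypothetical, the sketch assumes a 16-bit PCM WAV file, and it uses only the Python standard library. It averages the channels of each frame and writes a single-channel file at the same sample rate.

```python
import array
import sys
import wave


def downmix_interleaved(samples, n_channels):
    """Average interleaved 16-bit samples across channels (hypothetical helper)."""
    n_frames = len(samples) // n_channels
    # Each frame holds one sample per channel; floor-average them into one value.
    return array.array(
        "h",
        (sum(samples[i * n_channels:(i + 1) * n_channels]) // n_channels
         for i in range(n_frames)),
    )


def wav_to_mono(src_path, dst_path):
    """Write a mono copy of a 16-bit PCM WAV file; return the source channel count."""
    with wave.open(src_path, "rb") as src:
        n_channels = src.getnchannels()
        sample_width = src.getsampwidth()
        frame_rate = src.getframerate()
        raw = src.readframes(src.getnframes())

    if sample_width != 2:
        # Assumption of this sketch: only 16-bit samples are handled.
        raise ValueError("this sketch only handles 16-bit PCM")

    samples = array.array("h")
    samples.frombytes(raw)
    if sys.byteorder == "big":
        samples.byteswap()  # WAV data is little-endian

    mono = downmix_interleaved(samples, n_channels)
    if sys.byteorder == "big":
        mono.byteswap()

    with wave.open(dst_path, "wb") as dst:
        dst.setnchannels(1)
        dst.setsampwidth(2)
        dst.setframerate(frame_rate)
        dst.writeframes(mono.tobytes())
    return n_channels
```

Calling `wav_to_mono("diarization_voorbeeld_16khz.wav", "mono.wav")` (the output name is only an example) and pointing `audio_path` at the new file should give the encoder a single-channel input. Other audio tools can do the same downmix; this version simply avoids extra dependencies.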

tjongsma commented 1 week ago

Ah thanks for the prompt reply! That should be easy enough to fix, will get on it :)