modelscope / FunASR

A Fundamental End-to-End Speech Recognition Toolkit and Open Source SOTA Pretrained Models, Supporting Speech Recognition, Voice Activity Detection, Text Post-processing etc.
https://www.funasr.com

Hotwords model RuntimeError: shape '[4, -1, 4, 128]' is invalid for input of size 512 #758

Closed lingfengchencn closed 1 year ago

lingfengchencn commented 1 year ago

Code:

input_wav = "../storage/audios/单声道16k.wav"
from pydub import AudioSegment
audio = AudioSegment.from_file(input_wav)
audio = audio.set_frame_rate(16000)

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

param_dict = dict()
hotwords_model_dir = 'damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404'
paraformer_model_dir = 'damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
paraformer_large_dir = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'

# contextual (hotwords) model combined with a separate VAD model
hotwords_model = pipeline(
    task=Tasks.auto_speech_recognition,
    model=hotwords_model_dir,
    vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
    param_dict=param_dict)

# plain Paraformer-large model with the same VAD model
paraformer_large_model = pipeline(
    task=Tasks.auto_speech_recognition,
    model=paraformer_large_dir,
    vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
    param_dict=param_dict)

asr_result = paraformer_large_model(audio_in=audio.raw_data, batch_size=1, batch_size_token=1000)
print(asr_result)
hot_asr_result = hotwords_model(audio_in=audio.raw_data, batch_size=1, batch_size_token=1000)
print(asr_result, hot_asr_result)

Output of the first print (only the relevant lines):

2023-07-20 16:18:01,080 - modelscope - INFO - Decoding with pcm files ...
2023-07-20 16:18:01,080 (asr_inference_pipeline:485) INFO: Decoding with pcm files ...
time cost asr:  0.2628657817840576
{'text': '啊喂哎喂你......好拜拜好拜拜嗯嗯', 'sentences': []}
batch_size_token:  1000
time cost vad:  0.4567749500274658
batch:  4

Output of the second print (the hotwords model call):


File ~/miniconda3/envs/asr/lib/python3.10/site-packages/modelscope/pipelines/audio/asr_inference_pipeline.py:575, in AutomaticSpeechRecognitionPipeline.run_inference(self, cmd, **kwargs)
    574 def run_inference(self, cmd, **kwargs):
--> 575     asr_result = self.funasr_infer_modelscope(cmd['name_and_type'],
    576                                               cmd['raw_inputs'],
    577                                               cmd['output_dir'], cmd['fs'],
    578                                               cmd['param_dict'], **kwargs)
    580     return asr_result

File /mnt/ai/asr-service/tests/FunASR/funasr/bin/asr_inference_launch.py:661, in inference_paraformer_vad_punc.._forward(data_path_and_name_and_type, raw_inputs, output_dir_v2, fs, param_dict, **kwargs)
    659 print("batch: ", speech_j.shape[0])
    660 beg_asr = time.time()
--> 661 results = speech2text(**batch)
    662 end_asr = time.time()
    663 print("time cost asr: ", end_asr - beg_asr)

File ~/miniconda3/envs/asr/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__..decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /mnt/ai/asr-service/tests/FunASR/funasr/bin/asr_infer.py:452, in Speech2TextParaformer.__call__(self, speech, speech_lengths, begin_time, end_time)
    450     decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
    451 else:
--> 452     decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, 
    453                                                              enc_len, 
    454                                                              pre_acoustic_embeds,
    455                                                              pre_token_length, 
    456                                                              hw_list=self.hotword_list,
    457                                                              clas_scale=self.clas_scale)
    458     decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
    460 if isinstance(self.asr_model, BiCifParaformer):

File /mnt/ai/asr-service/tests/FunASR/funasr/models/e2e_asr_contextual_paraformer.py:365, in NeatContextualParaformer.cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list, clas_scale)
    362     _, (h_n, _) = self.bias_encoder(hw_embed)
    363     hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
--> 365 decoder_outs = self.decoder(
    366     encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
    367 )
    368 decoder_out = decoder_outs[0]
    369 decoder_out = torch.log_softmax(decoder_out, dim=-1)

File ~/miniconda3/envs/asr/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/ai/asr-service/tests/FunASR/funasr/models/decoder/contextual_decoder.py:284, in ContextualParaformerDecoder.forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, contextual_info, clas_scale, return_hidden)
    282 contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
    283 contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
--> 284 cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
    286 if self.bias_output is not None:
    287     x = torch.cat([x_src_attn, cx*clas_scale], dim=2)

File ~/miniconda3/envs/asr/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/ai/asr-service/tests/FunASR/funasr/models/decoder/contextual_decoder.py:98, in ContextualBiasDecoder.forward(self, tgt, tgt_mask, memory, memory_mask, cache)
     96     if self.normalize_before:
     97         x = self.norm3(x)
---> 98     x =  self.dropout(self.src_attn(x, memory, memory_mask))
     99 return x, tgt_mask, memory, memory_mask, cache

File ~/miniconda3/envs/asr/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/ai/asr-service/tests/FunASR/funasr/modules/attention.py:639, in MultiHeadedAttentionCrossAtt.forward(self, x, memory, memory_mask)
    625 def forward(self, x, memory, memory_mask):
    626     """Compute scaled dot product attention.
    627 
    628     Args:
   (...)
    637 
    638     """
--> 639     q_h, k_h, v_h = self.forward_qkv(x, memory)
    640     q_h = q_h * self.d_k ** (-0.5)
    641     scores = torch.matmul(q_h, k_h.transpose(-2, -1))

File /mnt/ai/asr-service/tests/FunASR/funasr/modules/attention.py:583, in MultiHeadedAttentionCrossAtt.forward_qkv(self, x, memory)
    581 k_v = self.linear_k_v(memory)
    582 k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
--> 583 k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
    584 v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
    587 return q_h, k_h, v_h

RuntimeError: shape '[4, -1, 4, 128]' is invalid for input of size 512
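For what it's worth, here is a minimal sketch of the shape mismatch (the 4 heads, head dim 128, and batch 4 come from the error message above; the batch-1 memory is my assumption about the empty hotword embedding): the key projection of a batch-1, length-1 memory has only 1 × 1 × 512 = 512 elements, while a reshape to (4, -1, 4, 128) needs a multiple of 4 × 4 × 128 = 2048.

import torch

b, h, d_k = 4, 4, 128                # decoder batch, attention heads, head dim (from the error message)
k = torch.zeros(1, 1, h * d_k)       # memory left at batch size 1 -> only 512 elements
torch.reshape(k, (b, -1, h, d_k))    # raises: shape '[4, -1, 4, 128]' is invalid for input of size 512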

When I removed vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch' from the hotwords_model pipeline, it works fine:

hotwords_model = pipeline(
    task=Tasks.auto_speech_recognition,
    model=hotwords_model_dir,
    param_dict=param_dict)

hot_asr_result = hotwords_model(audio_in=audio.raw_data, batch_size=1, batch_size_token=1000)
print(asr_result, hot_asr_result)

Result:

2023-07-20 16:31:00,693 - modelscope - INFO - Computing the result of ASR ...
2023-07-20 16:31:00,693 (asr_inference_pipeline:509) INFO: Computing the result of ASR ...
{'text': '啊喂哎...好拜拜好拜拜嗯嗯'}

I don't know why. shixian tested my audio file and code (with VAD) and it works fine for him.

R1ckShi commented 1 year ago

Fixed. The problem occurred in the pipeline that combines the VAD model with the hotword model: in that setup the hotword model's inference runs with multiple batches, and when the hotword list is empty the embedding was not repeated along the batch dimension, which caused the attention head (H) and depth (D) dimensions to be handled incorrectly during the reshape.
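A minimal sketch of what that means for the shapes (assumed sizes, not the actual FunASR patch): the hotword embedding used as cross-attention memory has to be repeated to the ASR batch size before the per-head reshape.

import torch

batch, h, d_k = 4, 4, 128
hw_embed = torch.zeros(1, 1, h * d_k)     # empty-hotword embedding, batch dim still 1
hw_embed = hw_embed.repeat(batch, 1, 1)   # repeat along the batch dimension -> (4, 1, 512)
k_h = torch.reshape(hw_embed, (batch, -1, h, d_k)).transpose(1, 2)   # (4, 4, 1, 128), no error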