wenet-e2e / wenet

Production First and Production Ready End-to-End Speech Recognition Toolkit
https://wenet-e2e.github.io/wenet/
Apache License 2.0

error in _update_kv_and_cache with conformer model #2653

Open anjul1008 opened 6 days ago

anjul1008 commented 6 days ago

I tried exporting the streaming Conformer model to ONNX format with the following command:

python3 wenet/bin/export_onnx_gpu.py  --config=$model_dir/train.yaml --checkpoint=$model_dir/final.pt --cmvn_file=$model_dir/global_cmvn --ctc_weight=0.5 --output_onnx_dir=$onnx_model_dir  --streaming --return_ctc_logprobs   --fp16

Conformer model architecture:

encoder: conformer
encoder_conf:
  activation_type: swish
  attention_dropout_rate: 0.1
  attention_heads: 8
  causal: true
  cnn_module_kernel: 15 #31
  cnn_module_norm: layer_norm
  dropout_rate: 0.1
  gradient_checkpointing: true
  input_layer: conv2d
  linear_units: 2048
  normalize_before: true
  num_blocks: 12
  output_size: 512
  pos_enc_layer_type: rel_pos
  positional_dropout_rate: 0.1
  selfattention_layer_type: rel_selfattn
  use_cnn_module: true
  use_dynamic_chunk: true
  use_dynamic_left_chunk: false
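For orientation, the attention cache that a streaming export has to carry can be sized from this config. The sketch below assumes the combined-cache layout described in WeNet's `forward_chunk` docstring, `(elayers, head, cache_t1, d_k * 2)` with `head * d_k == output_size` and `cache_t1 == chunk_size * num_decoding_left_chunks`. The helper name `att_cache_shape` and the chunk values are hypothetical, not taken from this issue.

```python
# Sketch: size the combined attention cache from encoder_conf.
# ASSUMPTION: combined layout (elayers, head, cache_t1, d_k * 2) with
# head * d_k == output_size and cache_t1 == chunk_size * num_left_chunks.
# att_cache_shape and the chunk values below are hypothetical.

def att_cache_shape(encoder_conf, chunk_size, num_left_chunks):
    heads = encoder_conf["attention_heads"]
    hidden = encoder_conf["output_size"]
    if hidden % heads != 0:
        raise ValueError("output_size must be divisible by attention_heads")
    d_k = hidden // heads                    # per-head dimension
    cache_t1 = chunk_size * num_left_chunks  # cached left-context frames
    # Key and value share the last axis, hence d_k * 2.
    return (encoder_conf["num_blocks"], heads, cache_t1, d_k * 2)


# Values copied from the encoder_conf above.
conf = {"attention_heads": 8, "output_size": 512, "num_blocks": 12}
print(att_cache_shape(conf, chunk_size=16, num_left_chunks=4))  # (12, 8, 64, 128)
```

With 8 heads and output_size 512, each head has d_k = 64, so a combined key/value cache is 128 wide on its last axis, while each half of a separate (key, value) pair would be 64 wide.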

Error produced:

    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1291, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 138, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 129, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1500, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/media/hulk2/BigData2/wenet_23_jan_2024/examples/reverie/v5/s0/wenet/bin/export_onnx_gpu.py", line 172, in forward
    xs, _, new_att_cache, new_cnn_cache = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1500, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/media/hulk2/BigData2/wenet_23_jan_2024/examples/reverie/v5/s0/wenet/transformer/encoder_layer.py", line 234, in forward
    x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1500, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/media/hulk2/BigData2/wenet_23_jan_2024/examples/reverie/v5/s0/wenet/transformer/attention.py", line 399, in forward
    k, v, new_cache = self._update_kv_and_cache(k, v, cache)
  File "/media/hulk2/BigData2/wenet_23_jan_2024/examples/reverie/v5/s0/wenet/transformer/attention.py", line 209, in _update_kv_and_cache
    key_cache, value_cache = cache
ValueError: too many values to unpack (expected 2)
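The last frame shows why the message says "too many" values: `key_cache, value_cache = cache` iterates `cache` along its first axis, so it only works when `cache` is a two-element pair. If a single combined cache tensor reaches that line, Python unpacks the tensor's leading dimension instead. A framework-free illustration, with nested lists standing in for tensors; both function names are hypothetical, and the second one only mirrors line 209:

```python
# Illustration only: nested lists stand in for torch tensors. Unpacking
# a tensor, like unpacking a list, iterates over its first axis.

def make_combined_cache(leading, heads, time, d_k):
    # A single array holding key and value side by side on the last axis:
    # shape (leading, heads, time, 2 * d_k).
    return [[[[0.0] * (2 * d_k) for _ in range(time)]
             for _ in range(heads)]
            for _ in range(leading)]


def update_kv_and_cache_expecting_pair(cache):
    # Hypothetical stand-in for the failing line 209 of attention.py.
    key_cache, value_cache = cache
    return key_cache, value_cache


cache = make_combined_cache(leading=4, heads=8, time=3, d_k=64)
try:
    update_kv_and_cache_expecting_pair(cache)
except ValueError as err:
    print(f"ValueError: {err}")  # reports too many values, as in the traceback
```

A leading axis of exactly 2 would not raise at all and would silently treat two slices of that axis as key and value, while a leading axis of 1 raises "not enough values to unpack" instead.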

I tried to fix it but have had no luck yet. Any help would be appreciated. Thanks in advance.

xingchensong commented 6 days ago

The cache API changed in this pull request (https://github.com/wenet-e2e/wenet/pull/2481), but export_onnx_gpu.py was not updated accordingly.
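To make the layout mismatch concrete, assume the layers used to take one cache tensor with key and value concatenated on the last axis (the `d_k * 2` layout), while the traceback shows the new `_update_kv_and_cache` unpacking `cache` as a `(key_cache, value_cache)` pair. An export script that keeps a single-tensor ONNX cache input would then have to split it before calling the layer and, assuming the layer now also returns a pair, concatenate the result back for the output. The hypothetical helpers `split_kv` and `merge_kv` sketch that conversion on nested lists; in PyTorch the same steps would be `torch.split(att_cache, att_cache.size(-1) // 2, dim=-1)` and `torch.cat((k, v), dim=-1)`. This illustrates the layout difference only and is not the change made in PR #2654.

```python
# Hypothetical adapter between two cache layouts (nested lists stand in
# for tensors):
#   combined: one array, last axis = [key (d_k) | value (d_k)]
#   pair:     (key_cache, value_cache), each with last axis d_k

def _map_last_axis(x, fn):
    # Apply fn to every innermost list, preserving the outer structure.
    if x and isinstance(x[0], list):
        return [_map_last_axis(item, fn) for item in x]
    return fn(x)


def split_kv(combined):
    # Same idea as torch.split(cache, cache.size(-1) // 2, dim=-1).
    key = _map_last_axis(combined, lambda row: row[: len(row) // 2])
    value = _map_last_axis(combined, lambda row: row[len(row) // 2:])
    return key, value


def merge_kv(key, value):
    # Inverse of split_kv: concatenate along the last axis,
    # like torch.cat((k, v), dim=-1).
    if key and isinstance(key[0], list):
        return [merge_kv(k, v) for k, v in zip(key, value)]
    return key + value


combined = [[[[1, 2, 3, 4], [5, 6, 7, 8]]]]  # shape (1, 1, 2, 4), so d_k = 2
key_cache, value_cache = split_kv(combined)   # unpacks cleanly into a pair
print(key_cache)    # [[[[1, 2], [5, 6]]]]
print(value_cache)  # [[[[3, 4], [7, 8]]]]
assert merge_kv(key_cache, value_cache) == combined
```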

Mddct commented 5 days ago

Please refer to https://github.com/wenet-e2e/wenet/pull/2654.

Sorry, I don't have a GPU environment right now. Please apply the changes from this PR and run some recognition tests to help verify it.
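One way to do the suggested recognition check is to decode the same utterances with the PyTorch checkpoint and with the exported ONNX model, then compare the two sets of transcripts by character error rate. The helpers `edit_distance` and `cer` below are hypothetical and framework-free, and the example transcripts are made up for illustration, not results from this issue.

```python
# Hypothetical verification helper: character error rate between
# transcripts from the PyTorch checkpoint (treated as reference) and
# transcripts from the exported ONNX model.

def edit_distance(ref, hyp):
    # Levenshtein distance over characters, O(len(ref) * len(hyp)).
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # deletion
                         cur[j - 1] + 1,          # insertion
                         prev[j - 1] + (r != h))  # substitution or match
        prev = cur
    return prev[-1]


def cer(refs, hyps):
    if len(refs) != len(hyps):
        raise ValueError("refs and hyps must have the same length")
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return errors / total if total else 0.0


torch_hyps = ["hello world", "good morning"]  # made-up example transcripts
onnx_hyps = ["hello world", "good mornin"]
print(cer(torch_hyps, onnx_hyps))  # 1 error over 23 reference characters
```

Identical transcripts (CER 0.0) on a set of test utterances would be a reasonable first sign that the export works; small differences may also come from the `--fp16` option rather than from the cache handling.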