espnet / espnet_onnx

ONNX wrapper for ESPnet inference models
MIT License

ASR onnx export accuracy drop #49

Open yuekaizhang opened 2 years ago

yuekaizhang commented 2 years ago

Hi, I was wondering whether you have tested the accuracy of models exported with espnet_onnx. I have run into an accuracy drop from ~5% WER to 12% WER after switching to ONNX. Has anyone else encountered the same issue?
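For context, the WER numbers in this thread are the standard Levenshtein-distance metric over word sequences; a minimal self-contained sketch (not espnet_onnx code, names are illustrative):

```python
def wer(ref_words, hyp_words):
    """Word error rate: edit distance between word sequences / reference length."""
    m, n = len(ref_words), len(hyp_words)
    # dp[i][j] = edit distance between ref_words[:i] and hyp_words[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            sub = dp[i - 1][j - 1] + (ref_words[i - 1] != hyp_words[j - 1])
            dp[i][j] = min(sub, dp[i - 1][j] + 1, dp[i][j - 1] + 1)
    return dp[m][n] / max(m, 1)

print(wer("the cat sat".split(), "the cat sat".split()))  # 0.0
print(wer("the cat sat".split(), "the bat sat".split()))  # 1 substitution / 3 words
```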

sw005320 commented 2 years ago

Which models are you using? Unfortunately, this porting requires manual reimplementation, and degradations can happen because of this.

In addition, to debug all functions, we are also implementing tests that cover many functions to verify they produce the same results.

yuekaizhang commented 2 years ago

> Which models are you using? Unfortunately, this porting requires manual implementation, and degradations would happen due to this.
>
> In addition, to debug all functions, we are also trying to implement a test to cover many functions to produce the same results.

Hi Shinji, I am using the AISHELL model from https://zenodo.org/record/4105763#.YzELWtJBxAJ. I have found that the issue may be caused by the legacy rel-pos attention layer, which gives different results between the ONNX and the .pt model file.
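For readers unfamiliar with that layer: the legacy rel-pos attention applies a Transformer-XL style "relative shift" built entirely from pad/reshape/slice operations, which is exactly the kind of view-manipulation sequence that can diverge after ONNX export. A numpy re-sketch of the shift step (illustrative only, not the espnet_onnx implementation):

```python
import numpy as np

def legacy_rel_shift(x):
    """Transformer-XL style relative shift over attention scores.

    x: scores against relative position embeddings, shape
    (batch, head, time1, time2); the shift realigns each row.
    """
    b, h, t1, t2 = x.shape
    zero_pad = np.zeros((b, h, t1, 1), dtype=x.dtype)
    x_padded = np.concatenate([zero_pad, x], axis=-1)   # (b, h, t1, t2 + 1)
    x_padded = x_padded.reshape(b, h, t2 + 1, t1)       # re-interpret the memory layout
    return x_padded[:, :, 1:, :].reshape(b, h, t1, t2)  # drop the pad, restore the shape

scores = np.random.default_rng(0).standard_normal((1, 4, 10, 10))
print(legacy_rel_shift(scores).shape)  # (1, 4, 10, 10)
```

Comparing the output of this step between the PyTorch module and the exported graph on the same input is one way to localize where the parity breaks.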

sw005320 commented 2 years ago

Thanks! @Masao-Someki, can you take a look at the model?

Masao-Someki commented 2 years ago

Hi @yuekaizhang, thank you for reporting the issue! I recently found that the current beam search implementation in espnet_onnx can cause a large accuracy loss, and this case might be caused by that beam search bug. The parity issue with the legacy conformer is not solved yet, but I don't think it is the main reason for such a large accuracy drop. I will fix this as soon as possible; I think I can fix it this weekend.

pengaoao commented 2 years ago

I am exporting the same model; the ONNX CER is 11.7%, while the torch inference CER is 4.7%.

pengaoao commented 2 years ago

> I am exporting the same model, the onnx cer is 11.7%, while the torch infer cer is 4.7%

@Masao-Someki Could you please look at this model?

Masao-Someki commented 2 years ago

@yuekaizhang @pengaoao I'm sorry for the late reply; I fixed the beam_search-related bug. Would you check the accuracy with the latest espnet_onnx? Note that ctc_weight=0.3 in this model, so please check your configuration file in ~/.cache/espnet_onnx/<tag_name>/config.yml; ctc_weight is written at the bottom of the file.
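For orientation, the relevant section of config.yml looks roughly like the fragment below. This is only an illustration; the exact key names and nesting may differ in your generated file, so check the bottom of your own config.yml:

```yaml
# Illustrative fragment only -- key names may differ in your
# ~/.cache/espnet_onnx/<tag_name>/config.yml
weights:
  decoder: 0.7
  ctc: 0.3
  lm: 0.3
```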

pengaoao commented 2 years ago

> @yuekaizhang @pengaoao I'm sorry for the late reply, I fixed the beam_search-related bug. Would you check the accuracy with the latest espnet_onnx? Note that ctc_weight=0.3 in the model, so please check your configuration file in ~/.cache/espnet_onnx/<tag_name>/config.yml. ctc_weight is written at the bottom of the file.

Thanks, I'm trying to test on AISHELL, but something is wrong:

```
2022-10-09 02:24:56.152717163 [W:onnxruntime:, execution_frame.cc:812 VerifyOutputSizes] Expected shape from model of {1,24} does not match actual shape of {1,108} for output encoder_out_lens
2022-10-09 02:24:56.305905698 [W:onnxruntime:, execution_frame.cc:812 VerifyOutputSizes] Expected shape from model of {1,2,256} does not match actual shape of {20,2,256} for output out_cache_0
2022-10-09 02:24:56.306356162 [E:onnxruntime:, sequential_executor.cc:368 Execute] Non-zero status code returned while running Reshape node. Name:'Reshape_296' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{20,2,256}, requested shape:{1,2,4,64}
```

```
Traceback (most recent call last):
  File "infer2.py", line 20, in <module>
    nbest = speech2text(y)
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/espnet_onnx/asr/asr_model.py", line 84, in __call__
    nbest_hyps = self.beam_search(enc[0])[:1]
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/espnet_onnx/asr/beam_search/beam_search.py", line 334, in __call__
    best = self.search(running_hyps, x)
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/espnet_onnx/asr/beam_search/batch_beam_search.py", line 195, in search
    [x for _ in range(n_batch)]).reshape(n_batch, *x.shape))
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/espnet_onnx/asr/beam_search/batch_beam_search.py", line 136, in score_full
    scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/espnet_onnx/asr/model/decoders/xformer.py", line 86, in batch_score
    input_dict
  File "/usr/local/python3.7.5/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_296' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, onnxruntime::TensorShapeVector&, bool) gsl::narrow_cast(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{20,2,256}, requested shape:{1,2,4,64}
```
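The Reshape failure in the log is a plain element-count mismatch: the decoder's attention cache has grown to 20 steps, but the exported graph expects a fixed single-step shape. A quick sanity check in plain numpy (shapes taken from the error message, everything else illustrative):

```python
import numpy as np

actual = (20, 2, 256)      # cache shape after 20 decoding steps
requested = (1, 2, 4, 64)  # fixed shape baked into the exported graph

# A reshape is only legal when the total element counts match.
print(int(np.prod(actual)), int(np.prod(requested)))  # 10240 512

x = np.zeros(actual)
try:
    x.reshape(requested)
except ValueError as e:
    # numpy rejects it for the same reason onnxruntime does
    print("reshape failed:", e)
```

This points at a dynamic axis (the cache length) having been exported as a fixed dimension, which is consistent with the torch-version dependence reported below.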

export script:

```python
from espnet_onnx.export import ASRModelExport

m = ASRModelExport()
m.export_from_zip(
    'asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp_valid.acc.ave.zip',
    tag_name='asr_train_asr')
```

infer script:

```python
import os
import re

import librosa
from espnet_onnx import Speech2Text
from pyacl.acl_infer import init_acl, release_acl
from tqdm import tqdm


def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            if re.match(r'.*\d.*', f):
                fullname = os.path.join(root, f)
                yield fullname


init_acl(0)

speech2text = Speech2Text(tag_name='')
speech2text = Speech2Text(model_dir='/root/.cache/espnet_onnx/asr_train_asr/')
path = "sample.wav"
y, sr = librosa.load(path, sr=16000)
nbest = speech2text(y)
print(nbest)
```

pengaoao commented 2 years ago


I changed the torch version to 1.12 and the error disappeared; I'm testing the accuracy now.

pengaoao commented 2 years ago

I tested on 3000 files with ctc_weight=0.3; the CER is 14%, so the problem does not seem to be solved.

Masao-Someki commented 2 years ago

@pengaoao Thank you for reporting. I investigated and found a parity issue with the TransformerDecoder. I fixed it in #52 and got the same result sentences with some sample wav files from AISHELL. Would you test again?

pengaoao commented 2 years ago

The CER is 7.43% with ctc_weight=0.5 and 8.92% with ctc_weight=0.3; something is still wrong.

pengaoao commented 1 year ago

In addition to ctc_weight=0.3, I set the decoder weight to 0.7 and the LM weight to 0.3, and then the CER is correct.
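For reference, in hybrid CTC/attention decoding the hypothesis score is a weighted sum of the partial log-probabilities, which is why these weights matter so much for CER. A minimal sketch using the weights above (function and parameter names are illustrative, not the espnet_onnx API):

```python
def combined_score(att_score, ctc_score, lm_score,
                   ctc_weight=0.3, lm_weight=0.3):
    """Weighted log-prob combination used in hybrid CTC/attention decoding.

    The decoder (attention) weight is 1 - ctc_weight, i.e. 0.7 here.
    """
    return ((1.0 - ctc_weight) * att_score
            + ctc_weight * ctc_score
            + lm_weight * lm_score)

# Made-up log-probabilities, for illustration:
print(combined_score(att_score=-2.0, ctc_score=-3.0, lm_score=-1.5))  # -2.75
```

If the exported model's config silently uses different weights than the PyTorch recipe, the beam search ranks hypotheses differently even when every network output matches.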

Masao-Someki commented 1 year ago

Thank you @pengaoao. We have fixed bugs in the STFT frontend, and now we get exactly the same results. @yuekaizhang, would you check performance with the latest version?
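A frontend bug like this is easy to catch with a feature-level diff before comparing encoders at all: run the same waveform through both frontends and assert the features are numerically close. A self-contained sketch with a naive framed STFT standing in for the frontend (illustrative; not the espnet_onnx implementation, and in practice `b` would come from the other backend):

```python
import numpy as np

def stft_mag(x, n_fft=512, hop=128):
    """Naive framed STFT magnitude, standing in for an ASR frontend."""
    win = np.hanning(n_fft)
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop:i * hop + n_fft] * win for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))

# Parity-style check: identical inputs must yield near-identical features
# before any encoder comparison is meaningful.
x = np.random.default_rng(0).standard_normal(16000)  # 1 s at 16 kHz
a = stft_mag(x)          # "PyTorch" side, here just the same function
b = stft_mag(x.copy())   # "ONNX" side, likewise
np.testing.assert_allclose(a, b, rtol=1e-6)
print(a.shape)  # (122, 257)
```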

sanjuktasr commented 1 year ago

Hi @Masao-Someki, could you please tell me whether similar problems affect the streaming conformer?

Masao-Someki commented 1 year ago

@sanjuktasr You can check for parity issues with espnet_onnx by running the test_inference_asr test; the contextual conformer block does not have a parity issue.