NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

conv asr onnx export no signal length input #2317

Closed Slyne closed 2 years ago

Slyne commented 3 years ago

Describe the bug I'm trying to export a QuartzNet model to ONNX using the asr_model.export() function. However, the generated ONNX model only has 'audio_signal' as an input; there is no input for the signal lengths. As a result, after converting the ONNX model to a TRT engine, I find the inference results have a much worse WER than running inference with NeMo directly, and the TRT engine's accuracy is sensitive to the batch size. Could you help confirm whether this is a bug?

I find that in the Conformer encoder, the input_example used to generate the ONNX model includes the input example lengths. In conv_asr, however, there is only the input example itself.

Steps/Code to reproduce bug

# generate calibrated model
python3 speech_to_text_calibrate.py --dataset=/raid/data/LibriSpeech/manifest.dev-other --asr_model=QuartzNet15x5Base-En 

# generate onnx model
python3 speech_to_text_quant_infer.py --dataset=/raid/data/LibriSpeech/manifest.dev-other --asr_model=QuartzNet15x5Base-En-max-256.nemo --onnx
# Got WER  10.7x%

python3 speech_to_text_quant_infer_trt.py --dataset=/raid/data/LibriSpeech/manifest.dev-other --asr_model=QuartzNet15x5Base-En-max-256.nemo --asr_onnx=./QuartzNet15x5Base-En-max-256.onnx --qat
# Got WER 11.x%

Expected behavior Expect speech_to_text_quant_infer.py to produce the same result as speech_to_text_quant_infer_trt.py

Environment overview (please complete the following information) nvcr.io/nvidia/pytorch:21.05-py3

titu1994 commented 3 years ago

I'm not very familiar with ONNX, @borisfom could you take a look at this?

VahidooX commented 3 years ago

@borisfom would comment on the ONNX export and why length is ignored. But AFAIK with QuartzNet you do not need to pass the lengths, as they are not used at all in the forward pass (unless you enable the masked convolutions). You just need to mask out the outputs in the decoding procedure. For Conformer, lengths are necessary because we need to mask the paddings in self-attention.
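The "mask out the outputs in the decoding procedure" step above can be sketched as a greedy CTC decode that only looks at the first `length` frames of each utterance. This is a hedged illustration, not NeMo's actual decoder; the function name, array shapes, and NumPy stand-ins for tensors are assumptions.

```python
import numpy as np

def greedy_ctc_decode(log_probs, lengths, blank_id=0):
    """Greedy CTC decoding that masks out padded frames.

    log_probs: (batch, time, vocab) array of per-frame scores.
    lengths:   per-utterance valid frame counts; frames beyond each
               length are padding and are simply ignored.
    """
    results = []
    for probs, valid_len in zip(log_probs, lengths):
        # Only decode the valid frames -- this is the
        # "mask out the outputs in decoding" step.
        frame_ids = probs[:valid_len].argmax(axis=-1)
        tokens, prev = [], None
        for token_id in frame_ids:
            # Collapse repeats and drop the blank symbol (CTC rule).
            if token_id != prev and token_id != blank_id:
                tokens.append(int(token_id))
            prev = token_id
        results.append(tokens)
    return results
```

With padding-insensitive layers (no masked convolutions), the frames past each utterance's length carry garbage, so trimming them here is enough to recover the correct transcript.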

Slyne commented 3 years ago

@VahidooX I checked the config file for QuartzNet and it uses conv_mask: True. It would be better to give users an option to decide whether or not to include the input length.

Slyne commented 3 years ago

@borisfom Any update ?

Slyne commented 3 years ago

@titu1994 Do you mind if I add the sequence length to

https://github.com/NVIDIA/NeMo/blob/7ef1782b94386629fbfafece72f618096c33a9f3/nemo/collections/asr/modules/conv_asr.py#L259
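For reference, the proposed change would make input_example return a length tensor alongside the audio features, the way the Conformer encoder's input_example already does. A minimal sketch, using NumPy stand-ins for torch tensors; the function name, shapes, and defaults are assumptions, not the actual NeMo signature:

```python
import numpy as np

def input_example_with_lengths(batch=16, feat_in=64, max_time=256):
    # Hypothetical two-output input_example: audio features plus a
    # per-utterance valid-length vector, mirroring what the Conformer
    # encoder's input_example already provides for ONNX tracing.
    audio_signal = np.random.randn(batch, feat_in, max_time).astype(np.float32)
    # Here every example is "padded" to max_time; real utterances
    # would have varying lengths.
    lengths = np.full((batch,), max_time, dtype=np.int64)
    return audio_signal, lengths
```

Returning the lengths in the example input is what makes the ONNX exporter trace a second graph input (e.g. 'length') next to 'audio_signal'.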

titu1994 commented 3 years ago

I don't know the impact this will have on downstream tools that require the ONNX file in its current format.

@borisfom please advise here

borisfom commented 3 years ago

@Slyne : the change needs to be more extensive. QuartzNet removes masked convolutions for inference, and you would need to keep them, too:

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if isinstance(m, MaskedConv1d):
                m.use_mask = False
                m_count += 1
        Exportable._prepare_for_export(self, **kwargs)
        logging.warning(f"Turned off {m_count} masked convolutions")

This has to be controlled by some other flag, not the main conv_mask, because there are currently many nets that have this flag set to True for training but do not need it for inference. For your local experiments, you can try hacking it out. Please also note that we originally had to remove the masked convolutions to get the code to export to ONNX at all; that may have been fixed since, but I did not check.

Slyne commented 3 years ago

@borisfom "This has to be controlled by some other flag, not the main conv_mask - because there are currently many nets that do have this flag as True for training but do not need that for inference."

Do you mean that the sequence length will not affect the inference procedure for QuartzNet? Or that we can simply trim the QuartzNet output by sequence length?