NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Quantization for CitriNet-1024-Gamma-0.25 #2830

Closed meghmak13 closed 3 years ago

meghmak13 commented 3 years ago

Describe the bug

I was trying to convert CitriNet-1024-Gamma-0.25 to a TensorRT quantized model, following the example workflow given under the quantization section, but I am facing the following issue.

Steps/Code to reproduce bug

root@nvidia-DGX-Station:/workspace/nemo/NeMo/examples/asr/quantization# python3 speech_to_text_calibrate.py --asr_model=citrinet-1024-gamma-0.25.nemo --dataset=test_hindi_navana.json
################################################################################

WARNING, path does not exist: KALDI_ROOT=/mnt/matylda5/iveselyk/Tools/kaldi-trunk

(please add 'export KALDI_ROOT=' in your $HOME/.profile)

(or run as: KALDI_ROOT= python .py)

################################################################################

[NeMo W 2021-09-16 14:04:05 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torchaudio-0.7.0a0+42d447d-py3.8-linux-x86_64.egg/torchaudio/backend/utils.py:53: UserWarning: "sox" backend is being deprecated. The default backend will be changed to "sox_io" backend in 0.8.0 and "sox" backend will be removed in 0.9.0. Please migrate to "sox_io" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail. warnings.warn(

[NeMo W 2021-09-16 14:04:06 experimental:27] Module <class 'nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo I 2021-09-16 14:04:07 speech_to_text_calibrate:80] Using local ASR model from citrinet-1024-gamma-0.25.nemo
[NeMo I 2021-09-16 14:04:19 mixins:147] Tokenizer SentencePieceTokenizer initialized with 72 tokens
[NeMo W 2021-09-16 14:04:19 modelPT:138] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config :
    manifest_filepath: ./manifest/hindi/train_hindi_data.json
    sample_rate: 16000
    batch_size: 16
    trim_silence: true
    max_duration: 20.0
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    use_start_end_token: true
    num_workers: 8
    pin_memory: true

[NeMo W 2021-09-16 14:04:19 modelPT:145] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s).
    Validation config :
    manifest_filepath: ./manifest/hindi/test_hindi_data.json
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: true
    num_workers: 8
    pin_memory: true
    trim_silence: true

[NeMo W 2021-09-16 14:04:19 modelPT:151] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).
    Test config :
    manifest_filepath: ./manifest/hindi/test_hindi_data.json
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: true
    num_workers: 8
    pin_memory: true
    trim_silence: true

[NeMo I 2021-09-16 14:04:19 features:252] PADDING: 16
[NeMo I 2021-09-16 14:04:19 features:269] STFT using torch
I0916 14:04:20.338965 140410652182336 _utils.py:72] Input is fake quantized to 8 bits in QuantConv1d with axis None!
I0916 14:04:20.339134 140410652182336 _utils.py:75] Weight is fake quantized to 8 bits in QuantConv1d with axis 0!
I0916 14:04:20.339357 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.339587 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.341764 140410652182336 _utils.py:72] Input is fake quantized to 8 bits in QuantConv1d with axis None!
I0916 14:04:20.341858 140410652182336 _utils.py:75] Weight is fake quantized to 8 bits in QuantConv1d with axis 0!
I0916 14:04:20.342052 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.342261 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.343992 140410652182336 _utils.py:126] Input is fake quantized to 8 bits in QuantAdaptiveAvgPool1d with axis None!
I0916 14:04:20.344176 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
[NeMo W 2021-09-16 14:04:20 modelPT:138] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config :
    manifest_filepath: ./manifest/hindi/train_hindi_data.json
    sample_rate: 16000
    batch_size: 16
    trim_silence: true
    max_duration: 20.0
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    use_start_end_token: true
    num_workers: 8
    pin_memory: true

[NeMo W 2021-09-16 14:04:20 modelPT:145] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s).
    Validation config :
    manifest_filepath: ./manifest/hindi/test_hindi_data.json
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: true
    num_workers: 8
    pin_memory: true
    trim_silence: true

[NeMo W 2021-09-16 14:04:20 modelPT:151] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).
    Test config :
    manifest_filepath: ./manifest/hindi/test_hindi_data.json
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: true
    num_workers: 8
    pin_memory: true
    trim_silence: true

[NeMo I 2021-09-16 14:04:20 features:252] PADDING: 16
[NeMo I 2021-09-16 14:04:20 features:269] STFT using torch
I0916 14:04:20.691268 140410652182336 _utils.py:72] Input is fake quantized to 8 bits in QuantConv1d with axis None!
I0916 14:04:20.691397 140410652182336 _utils.py:75] Weight is fake quantized to 8 bits in QuantConv1d with axis 0!
I0916 14:04:20.691546 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.691687 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.693689 140410652182336 _utils.py:72] Input is fake quantized to 8 bits in QuantConv1d with axis None!
I0916 14:04:20.693774 140410652182336 _utils.py:75] Weight is fake quantized to 8 bits in QuantConv1d with axis 0!
I0916 14:04:20.693904 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.694039 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
I0916 14:04:20.695321 140410652182336 _utils.py:126] Input is fake quantized to 8 bits in QuantAdaptiveAvgPool1d with axis None!
I0916 14:04:20.695476 140410652182336 tensor_quantizer.py:105] Creating Max calibrator
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 62, in _call_target
    return target(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/modules/conv_asr.py", line 166, in __init__
    JasperBlock(
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 658, in __init__
    SqueezeExcite(
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 342, in __init__
    self.change_context_window(context_window=context_window)
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 411, in change_context_window
    if not isinstance(self.pool, quant_nn.QuantAdaptiveAvgPool1d(1)):
TypeError: isinstance() arg 2 must be a type or tuple of types

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "speech_to_text_calibrate.py", line 160, in <module>
    main()  # noqa pylint: disable=no-value-for-parameter
  File "speech_to_text_calibrate.py", line 84, in main
    asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg)
  File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/modelPT.py", line 481, in restore_from
    return cls._default_restore_from(restore_path, override_config_path, map_location, strict, return_config)
  File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/modelPT.py", line 434, in _default_restore_from
    instance = cls.from_config_dict(config=conf)
  File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/common.py", line 472, in from_config_dict
    instance = cls(cfg=config)
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/models/ctc_models.py", line 155, in __init__
    self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)
  File "/opt/conda/lib/python3.8/site-packages/nemo/core/classes/common.py", line 437, in from_config_dict
    instance = hydra.utils.instantiate(config=config)
  File "/opt/conda/lib/python3.8/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 180, in instantiate
    return instantiate_node(config, *args, recursive=_recursive_, convert=_convert_)
  File "/opt/conda/lib/python3.8/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 249, in instantiate_node
    return _call_target(_target_, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 64, in _call_target
    raise type(e)(
  File "/opt/conda/lib/python3.8/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 62, in _call_target
    return target(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/modules/conv_asr.py", line 166, in __init__
    JasperBlock(
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 658, in __init__
    SqueezeExcite(
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 342, in __init__
    self.change_context_window(context_window=context_window)
  File "/opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/submodules/jasper.py", line 411, in change_context_window
    if not isinstance(self.pool, quant_nn.QuantAdaptiveAvgPool1d(1)):
TypeError: Error instantiating 'nemo.collections.asr.modules.conv_asr.ConvASREncoder' : isinstance() arg 2 must be a type or tuple of types
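The root cause is visible in the last frame of the traceback: `jasper.py` passes `quant_nn.QuantAdaptiveAvgPool1d(1)`, i.e. an *instance*, as the second argument of `isinstance()`, which only accepts a type or a tuple of types. A minimal self-contained sketch of the bug and the likely fix (using a hypothetical stand-in class instead of pytorch-quantization's `quant_nn`, so it runs without that library):

```python
# isinstance() requires a type (or tuple of types) as its second argument.
# The check in change_context_window() calls the class -- QuantAdaptiveAvgPool1d(1) --
# which passes an instance and raises the TypeError seen in the traceback.

class AdaptivePool:  # hypothetical stand-in for quant_nn.QuantAdaptiveAvgPool1d
    def __init__(self, output_size):
        self.output_size = output_size

pool = AdaptivePool(1)

# Buggy form (mirrors jasper.py line 411): second argument is an instance.
try:
    isinstance(pool, AdaptivePool(1))
except TypeError as exc:
    message = str(exc)  # e.g. "isinstance() arg 2 must be a type or tuple of types"

# Fixed form: pass the class itself, not an instance of it.
assert isinstance(pool, AdaptivePool)
```

The same one-character-class fix applies in NeMo: `isinstance(self.pool, quant_nn.QuantAdaptiveAvgPool1d)` rather than `quant_nn.QuantAdaptiveAvgPool1d(1)`.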

Environment details

NGC container: nvcr.io/nvidia/nemo:1.2.0

Additional context

GPU: V100

titu1994 commented 3 years ago

That's a bug. I'll patch it in the coming release.

titu1994 commented 3 years ago

Closed via https://github.com/NVIDIA/NeMo/pull/3062