NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Using PyTorch Lightning Quantization Aware Training callback #4617

Closed yttuncel closed 2 years ago

yttuncel commented 2 years ago

Can we use PL's QuantizationAwareTraining and ModelPruning callbacks during NeMo model training?

To quickly test this, I tried to add the QAT callback to the trainer in the speech_to_text_ctc.py script to train a quick QuartzNet 5x5 model:

from pytorch_lightning.callbacks import QuantizationAwareTraining

callbacks = [QuantizationAwareTraining(qconfig='qnnpack')]
trainer = pl.Trainer(callbacks=callbacks, **cfg.trainer)

and I get the following error:

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/ytuncel/.vscode-server/extensions/ms-python.python-2022.10.1/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/ytuncel/tiny_asr/src/tiny_asr/script.py", line 108, in <module>
    main()  # noqa pylint: disable=no-value-for-parameter
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/nemo/core/config/hydra_runner.py", line 104, in wrapper
    _run_hydra(
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/_internal/utils.py", line 377, in _run_hydra
    run_and_report(
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/_internal/utils.py", line 214, in run_and_report
    raise ex
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/_internal/utils.py", line 211, in run_and_report
    return func()
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/_internal/utils.py", line 378, in <lambda>
    lambda: hydra.run(
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 111, in run
    _ = ret.return_value
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/core/utils.py", line 233, in return_value
    raise self._return_value
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/hydra/core/utils.py", line 160, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/ytuncel/tiny_asr/src/tiny_asr/script.py", line 103, in main
    trainer.fit(asr_model)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1345, in _run_train
    self._run_sanity_check()
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1413, in _run_sanity_check
    val_loop.run()
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 154, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 128, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 226, in _evaluation_step
    output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 344, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/home/ytuncel/miniconda3/envs/asr/lib/python3.8/site-packages/nemo/collections/asr/models/ctc_models.py", line 621, in validation_step
    log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)
TypeError: wrapper() got an unexpected keyword argument 'input_signal'

Are there any examples that show how we can use the Pruning and QAT tools in PyTorch with NeMo?
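
For reference, the pruning counterpart I had in mind is PL's ModelPruning callback. A minimal sketch, where the pruning function and amount are arbitrary placeholder choices, not values taken from NeMo docs:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelPruning

# Prune 20% of the smallest-magnitude weights (L1 unstructured) during training.
# The callback itself is standard PL; whether NeMo models tolerate the in-place
# weight reparametrization is exactly what this issue is asking about.
callbacks = [ModelPruning(pruning_fn="l1_unstructured", amount=0.2)]
trainer = pl.Trainer(callbacks=callbacks, **cfg.trainer)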

titu1994 commented 2 years ago

We do not support PTL QAT; most NeMo classes cannot be used with pickling.

titu1994 commented 2 years ago

Well, I should rather say that we don't test for it. It may work, but we cannot offer much support for it.
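
If you want to experiment anyway, one route that side-steps the PL callback is PyTorch's eager-mode QAT API applied to a submodule by hand. Untested with NeMo, just a sketch; asr_model.encoder is an illustrative target, not a recommendation:

import torch.quantization as tq

# Sketch: attach a QAT qconfig to the encoder, insert fake-quant/observer
# modules, train as usual, then convert after training. For actual int8
# inference you would also need QuantStub/DeQuantStub around the quantized
# region and a CPU copy of the module for conversion.
asr_model.encoder.qconfig = tq.get_default_qat_qconfig("qnnpack")
tq.prepare_qat(asr_model.encoder, inplace=True)

trainer.fit(asr_model)  # training now runs with fake quantization

asr_model.encoder.eval()
tq.convert(asr_model.encoder, inplace=True)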

github-actions[bot] commented 2 years ago

This issue is stale because it has been open for 60 days with no activity.

github-actions[bot] commented 2 years ago

This issue was closed because it has been inactive for 7 days since being marked as stale.