NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

float16 is not supported by the preprocessor in Conformer CTC Large #5646

Closed galv closed 1 year ago

galv commented 1 year ago

Describe the bug

Not sure if this is intended to be supported or not, but I don't seem to be able to run the entire Conformer CTC Large model in fp16 on CUDA. The problem occurs because of a missing op in the preprocessor. The same error is mentioned here: https://github.com/pytorch/pytorch/issues/71680
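
For reference, the underlying PyTorch limitation can be reproduced in isolation. The snippet below is a hedged sketch that mimics the reflect-padding call torch.stft makes internally; whether it raises depends on the PyTorch build (per the linked issue, affected CPU builds lack a Half kernel for reflection padding).

import torch

# torch.stft (with the default center=True) reflect-pads its input before the FFT.
# On PyTorch builds affected by the linked issue there is no fp16 kernel for this op,
# so the call below raises: RuntimeError: "reflection_pad1d" not implemented for 'Half'
x = torch.randn(1, 32_000, dtype=torch.float16)
torch.nn.functional.pad(x, [200, 200], mode="reflect")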

Perhaps NeMo could run the preprocessor in fp32 while the rest of the model runs in fp16 as a workaround?
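
A rough, untested sketch of that workaround, keeping the preprocessor in fp32 and casting only the encoder and decoder to fp16, might look like the following. The direct encoder/decoder calls mirror what ctc_models.py does internally; their signatures are assumed and may differ between NeMo versions.

import nemo.collections.asr as nemo_asr
import torch

asr_model = nemo_asr.models.ASRModel.from_pretrained(
    model_name="stt_en_conformer_ctc_small", map_location=torch.device("cuda")
)
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
# Cast only the encoder and decoder; the preprocessor keeps fp32 weights and buffers.
asr_model.encoder.half()
asr_model.decoder.half()
asr_model.eval()

length = 16_000 * 2
# fp32 input so the fp32 preprocessor (torch.stft and friends) never sees Half tensors.
input_signal = torch.randn((1, length), device="cuda")
input_signal_length = torch.tensor([length], dtype=torch.int64, device="cuda")

with torch.no_grad():
    processed, processed_len = asr_model.preprocessor(
        input_signal=input_signal, length=input_signal_length
    )
    # Hand the fp32 features to the fp16 encoder/decoder.
    encoded, encoded_len = asr_model.encoder(audio_signal=processed.half(), length=processed_len)
    log_probs = asr_model.decoder(encoder_output=encoded)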

Steps/Code to reproduce bug

import nemo.collections.asr as nemo_asr
import torch

# Load a pretrained Conformer CTC model onto the GPU.
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name="stt_en_conformer_ctc_small", map_location=torch.device("cuda"))
# Disable dithering and padding in the feature extractor for a deterministic run.
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
# Cast the whole model, including the preprocessor, to fp16.
asr_model.half()
asr_model.eval()
asr_model.encoder.freeze()
asr_model.decoder.freeze()

# Two seconds of random audio at 16 kHz.
length = 16_000 * 2
input_signal = torch.randn((1, length), dtype=torch.float16)
input_signal_length = torch.tensor([length], dtype=torch.int64)

# This will crash
_ = asr_model.forward(input_signal=input_signal, input_signal_length=input_signal_length)

The error message is:

Traceback (most recent call last):
  File "/home/dgalvez/scratch/code/asr/riva-asrlib-decoder/src/riva/asrlib/decoder/reproducer.py", line 16, in <module>
    _ = asr_model.forward(input_signal=input_signal, input_signal_length=input_signal_length)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/core/classes/common.py", line 1084, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/collections/asr/models/ctc_models.py", line 540, in forward
    processed_signal, processed_signal_length = self.preprocessor(
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/core/classes/common.py", line 1084, in __call__
    outputs = wrapped(*args, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/collections/asr/modules/audio_preprocessing.py", line 85, in forward
    processed_signal, processed_length = self.get_features(input_signal, length)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/collections/asr/modules/audio_preprocessing.py", line 268, in get_features
    return self.featurizer(input_signal, length)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/collections/asr/parts/preprocessing/features.py", line 352, in forward
    x = self.stft(x)
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/nemo/collections/asr/parts/preprocessing/features.py", line 244, in <lambda>
    self.stft = lambda x: torch.stft(
  File "/home/dgalvez/scratch/miniconda3/envs/wfst/lib/python3.9/site-packages/torch/functional.py", line 630, in stft
    input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
RuntimeError: "reflection_pad1d" not implemented for 'Half'

Expected behavior

It would be great if float16 worked out of the box for NeMo


titu1994 commented 1 year ago

You're supposed to use torch.amp.autocast(). model.half() casts every parameter and buffer to fp16, including batch norm and many other operators that are not well supported in fp16 - for example torch.stft()

Even if stft() worked, batch norm would fail eventually
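
For illustration, a minimal sketch of the autocast approach might look like this (weights stay in fp32; autocast runs fp16-friendly ops such as matmuls and convolutions in half precision, while the fp32 input and ops autocast doesn't cover, like the preprocessor's torch.stft, stay in fp32):

import nemo.collections.asr as nemo_asr
import torch

asr_model = nemo_asr.models.ASRModel.from_pretrained(
    model_name="stt_en_conformer_ctc_small", map_location=torch.device("cuda")
)
asr_model.preprocessor.featurizer.dither = 0.0
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()

length = 16_000 * 2
input_signal = torch.randn((1, length), device="cuda")
input_signal_length = torch.tensor([length], dtype=torch.int64, device="cuda")

# Mixed-precision inference without model.half(): autocast decides per-op precision.
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    _ = asr_model.forward(input_signal=input_signal, input_signal_length=input_signal_length)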

github-actions[bot] commented 1 year ago

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot] commented 1 year ago

This issue was closed because it has been inactive for 7 days since being marked as stale.