NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Is it possible to load encoder from Rnnt to Cache-aware streaming conformer model #7842

Closed nabil6391 closed 1 year ago

nabil6391 commented 1 year ago

Describe the bug

Thanks to you guys I have successfully fine-tuned a Fast Conformer transducer model and it is performing quite well. I wanted to see whether I can continue fine-tuning with the Cache-Aware Streaming Conformer model, which looks like it might be better for real-time ASR.

Steps/Code to reproduce bug

I tried to load the state dict from the fine-tuned RNNT model into a streaming Fast Conformer model:

import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf

saved_model = "..."  # directory containing the fine-tuned model3.nemo

# Target: the pretrained cache-aware streaming hybrid model
model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
    model_name="stt_en_fastconformer_hybrid_large_streaming_multi"
)
# Source: the fine-tuned offline RNNT model
model2 = nemo_asr.models.EncDecRNNTBPEModel.restore_from(saved_model + "/model3.nemo")

cfg = OmegaConf.create({"init_from_nemo_model": saved_model + "/model3.nemo"})
model.maybe_init_from_pretrained_checkpoint(cfg)

But it gives this error output:

RuntimeError: Error(s) in loading state_dict for EncDecHybridRNNTCTCBPEModel:
    size mismatch for encoder.pre_encode.out.weight: copying a param with shape torch.Size([512, 2560]) from checkpoint, the shape in current model is torch.Size([512, 2816]).

Is there any way to load the RNNT encoder into the streaming one, perhaps by modifying the checkpoint or something else?
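
For reference, a workaround sketch (not an official NeMo API; it reuses the model and model2 objects from the snippet above) that copies only the encoder tensors whose shapes already match and skips mismatched ones such as encoder.pre_encode.out.weight:

# Copy only shape-compatible tensors from the fine-tuned RNNT encoder
# into the streaming encoder; leave the rest at their pretrained values.
src = model2.encoder.state_dict()
dst = model.encoder.state_dict()
compatible = {k: v for k, v in src.items() if k in dst and v.shape == dst[k].shape}
skipped = sorted(k for k in src if k not in compatible)
print(f"copying {len(compatible)} tensors, skipping {len(skipped)}")
# strict=False keeps the skipped parameters at their current values.
model.encoder.load_state_dict(compatible, strict=False)

Whether the copied weights are actually useful is a separate question, since the two encoders were trained as different architectures.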

Environment details


[NeMo I 2023-11-01 23:33:17 mixins:170] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo W 2023-11-01 23:33:18 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: drive/Shareddrives/QuranResources/models/nemo-conformer-text-quran/train_manifest.json
    sample_rate: 16000
    batch_size: 12
    shuffle: true
    num_workers: 0
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 20.0
    min_duration: 0.1
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    bucketing_strategy: fully_randomized
    bucketing_batch_size: null
    channel_selector: 0

[NeMo W 2023-11-01 23:33:18 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    manifest_filepath: drive/Shareddrives/QuranResources/models/nemo-conformer-text-quran/dev_manifest.json
    sample_rate: 16000
    batch_size: 12
    shuffle: false
    num_workers: 0
    pin_memory: true
    use_start_end_token: false
    max_duration: 20
    channel_selector: 0
    is_tarred: false

[NeMo W 2023-11-01 23:33:18 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).
    Test config : 
    manifest_filepath: drive/Shareddrives/QuranResources/models/nemo-conformer-text-quran/test_manifest.json
    sample_rate: 16000
    batch_size: 12
    shuffle: false
    num_workers: 0
    pin_memory: true
    use_start_end_token: false
    channel_selector: 0

[NeMo I 2023-11-01 23:33:18 features:289] PADDING: 0
[NeMo W 2023-11-01 23:33:18 nemo_logging:349] /opt/homebrew/Caskroom/miniconda/base/envs/nemo/lib/python3.10/site-packages/torch/nn/modules/rnn.py:82: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
      warnings.warn("dropout option adds dropout after all but last "

[NeMo I 2023-11-01 23:33:19 rnnt_models:211] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2023-11-01 23:33:19 save_restore_connector:249] Model EncDecRNNTBPEModel was successfully restored from /Users/nhossain/StudioProjects/AndroidStudioProjects/GenerateQuranTimeStamps/nemo/nemo-conformer-text-quran/model3.nemo.

titu1994 commented 1 year ago

Cache-aware Conformer is an encoder-level module, not a decoder-level one. So you cannot swap the Fast Conformer encoder for the cache-aware Conformer encoder: both are encoders, but with different architectures, which is why the checkpoint shapes do not match.
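
For reference, a quick inspection sketch (reusing model and model2 from the reproduction above; the shapes are taken from the error message) that makes the architectural difference visible:

# The pre-encode projections expect different input dimensions,
# so the tensors cannot be copied directly.
print(model.encoder.pre_encode.out.weight.shape)   # torch.Size([512, 2816]), streaming
print(model2.encoder.pre_encode.out.weight.shape)  # torch.Size([512, 2560]), offline RNNT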

nabil6391 commented 1 year ago

Thanks for the fast response. I tried fine-tuning the model without the encoder, but it seems like it's not converging as fast as the RNNT did; that might have changed if I had waited longer, I guess.
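
For reference, a minimal sketch of initializing everything except the encoder, assuming the dict form of init_from_nemo_model with an exclude filter (check that your NeMo version's ModelPT.maybe_init_from_pretrained_checkpoint supports this schema):

from omegaconf import OmegaConf

# Initialize decoder/joint from the fine-tuned checkpoint and leave the
# incompatible encoder at its pretrained streaming weights.
cfg = OmegaConf.create({
    "init_from_nemo_model": {
        "model0": {
            "path": saved_model + "/model3.nemo",
            "exclude": ["encoder"],
        }
    }
})
model.maybe_init_from_pretrained_checkpoint(cfg)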

Anyway, I will stick with the buffered RNNT for real-time for now. Thanks!