NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

FineTune ASR with freeze encoder #2334

Closed · 2Bye closed this issue 3 years ago

2Bye commented 3 years ago

Hello!

I am trying to fine-tune an ASR model with a frozen encoder on Russian.

I take the optimizer parameters and learning rate from this paper https://arxiv.org/pdf/2005.04290.pdf and set warmup_steps to 2000, but during training the loss plateaus and the WER stays at 1.0. What can I do to solve this problem?

model_weights.ckpt is unpacked from the stt_en_jasper10x5dr.nemo model. Code for training:

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecCTCModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="./", config_name="jasper_10x5dr.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer, amp_level='O1')
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)
    # Initialize the weights of the model from another model, if provided via config

    # Checkpoint keys are prefixed with "encoder." / "decoder."; keep only the encoder
    # entries and strip the 8-character "encoder." prefix before loading.
    state_dict = torch.load('model_weights.ckpt', map_location='cpu')
    encoder_state_dict = {key[8:]: value for key, value in state_dict.items() if key.split('.')[0] == 'encoder'}
    asr_model.encoder.load_state_dict(encoder_state_dict)
    del state_dict
    del encoder_state_dict
    torch.cuda.empty_cache()
    asr_model.encoder.freeze()

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)
if __name__ == '__main__':
    main()
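
(For reference, an equivalent way to restore the encoder weights without the manual prefix stripping is to load the published checkpoint through NeMo's from_pretrained and copy its encoder state dict; a sketch only, assuming the stt_en_jasper10x5dr checkpoint is reachable from NGC:)

# Sketch: pull the published Jasper checkpoint and reuse only its encoder weights.
pretrained = EncDecCTCModel.from_pretrained(model_name="stt_en_jasper10x5dr")
asr_model.encoder.load_state_dict(pretrained.encoder.state_dict())
del pretrained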
titu1994 commented 3 years ago

I'll be releasing a tutorial for fine-tuning on a new language with a frozen encoder in a few days, but for now I suggest freezing the encoder and unfreezing the batch norm layers inside it.
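
Roughly something like this (a quick sketch only, not the tutorial code; the enable_bn helper name is just illustrative):

import torch.nn as nn

def enable_bn(module):
    # Re-enable gradient updates and batch statistics for BatchNorm layers only.
    if isinstance(module, nn.BatchNorm1d):
        module.train()
        for param in module.parameters():
            param.requires_grad_(True)

asr_model.encoder.freeze()          # freeze every encoder parameter
asr_model.encoder.apply(enable_bn)  # then unfreeze just the BatchNorm layers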

2Bye commented 3 years ago

Okay, I'll try unfreezing all the BatchNorm layers in the encoder.

I will wait for your fine-tuning tutorial. Thanks!

2Bye commented 3 years ago

> I'll be releasing a tutorial for fine-tuning on a new language with a frozen encoder in a few days, but for now I suggest freezing the encoder and unfreezing the batch norm layers inside it.

I tried unfreezing these layers and trained the model for 3 days. The WER is still at 1.0, and the loss has plateaued.

Maybe I am doing something wrong? Code:

@hydra_runner(config_path="./", config_name="jasper_10x5dr.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer, amp_level='O1')
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)
    # Initialize the weights of the model from another model, if provided via config

    state_dict = torch.load('model_weights.ckpt' , map_location='cpu')
    print('load model done')
    encoder_state_dict = {key[8:] : value for key,value in state_dict.items() if key.split('.')[0] == 'encoder'}
    asr_model.encoder.load_state_dict(encoder_state_dict)
    del state_dict
    del encoder_state_dict
    torch.cuda.empty_cache()
    print('clear GPU memory')

    asr_model.encoder.freeze()
    True_class = type(asr_model.encoder.encoder[0].mconv[1])

    for i in asr_model.encoder.encoder:
        for j in i.mconv:
            if isinstance(j, True_class):
                j.requires_grad_ = True

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)

if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
titu1994 commented 3 years ago

You are unfreezing the entire encoder again with that True_class approach; you only need to unfreeze the batch norm within the module. I have a PR currently under review; see its section on unfreezing batch norm.

2Bye commented 3 years ago

I went through your Google Colab notebook and skipped a few of the dataset/vocabulary steps, because I provide that information in the config file.

I get these graphs:

[two training-metric screenshots]

I think something is wrong, namely that training is too slow. Code:

import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch

from nemo.collections.asr.models import EncDecCTCModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
import torch.nn as nn

def enable_bn_se(m):
    if type(m) == nn.BatchNorm1d:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)

    if 'SqueezeExcite' in type(m).__name__:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)

@hydra_runner(config_path="./", config_name="jasper_10x5dr.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer, amp_level='O1')
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

    state_dict = torch.load('model_weights.ckpt', map_location='cpu')
    encoder_state_dict = {key[8:] : value for key,value in state_dict.items() if key.split('.')[0] == 'encoder'}
    asr_model.encoder.load_state_dict(encoder_state_dict)
    del state_dict
    del encoder_state_dict

    asr_model.encoder.freeze()
    asr_model.encoder.apply(enable_bn_se)
    logging.info("Model encoder has been frozen, and batch normalization has been unfrozen")

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)

if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
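
(One quick check that can be added right after the freeze/apply calls, to confirm that only the BatchNorm and SqueezeExcite parameters remain trainable; a small illustrative addition, not part of the original script:)

# Count trainable vs. total encoder parameters; with a frozen encoder and
# unfrozen BatchNorm/SqueezeExcite, the trainable share should be a small fraction.
trainable = sum(p.numel() for p in asr_model.encoder.parameters() if p.requires_grad)
total = sum(p.numel() for p in asr_model.encoder.parameters())
logging.info(f"Trainable encoder parameters: {trainable} / {total}")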
titu1994 commented 3 years ago

I would advise you to go over every single step of the QuartzNet portion and match it to your script. I see you don't disable normalize_transcripts, don't use the pretrained checkpoint via from_pretrained, etc.

If after all that it still doesn't train, then the data is insufficient or there is some other issue in the labeling.
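
For reference, the QuartzNet-style fine-tuning steps being referred to look roughly like this (a sketch only, not the exact tutorial code; the alphabet list is truncated for illustration, and it assumes the trainer, cfg, and enable_bn_se objects from the script above plus access to the stt_en_jasper10x5dr checkpoint on NGC):

# Start from the published pretrained checkpoint instead of manual state_dict surgery.
asr_model = EncDecCTCModel.from_pretrained(model_name="stt_en_jasper10x5dr")
asr_model.set_trainer(trainer)

# Replace the decoder vocabulary with the target-language character set.
russian_alphabet = [" ", "а", "б", "в", "г"]  # illustrative; use the full alphabet here
asr_model.change_vocabulary(new_vocabulary=russian_alphabet)

# English-specific transcript normalization should be disabled for Russian manifests
# (add normalize_transcripts to the dataset config if it is not already defined there).
cfg.model.train_ds.normalize_transcripts = False
cfg.model.validation_ds.normalize_transcripts = False
asr_model.setup_training_data(cfg.model.train_ds)
asr_model.setup_validation_data(cfg.model.validation_ds)

# Freeze the encoder but keep BatchNorm (and SqueezeExcite, if present) trainable.
asr_model.encoder.freeze()
asr_model.encoder.apply(enable_bn_se)

trainer.fit(asr_model)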