2Bye closed this issue 3 years ago.
I'll be releasing a tutorial on fine-tuning for a new language with a frozen encoder in a few days. For now, I suggest freezing the encoder and unfreezing the batch norm layers inside it.
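In plain PyTorch terms, this suggestion means turning off gradients for every encoder parameter, then re-enabling them (and training mode) only for the batch-norm layers. A minimal sketch, assuming a generic nn.Module encoder; the toy encoder below is a hypothetical stand-in, not the actual Jasper encoder.

```python
import torch.nn as nn

BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_except_batchnorm(encoder: nn.Module) -> None:
    """Freeze every parameter, then unfreeze only the BatchNorm layers."""
    encoder.requires_grad_(False)  # no gradients for any encoder parameter
    encoder.eval()                 # frozen layers (e.g. dropout) behave as at inference
    for module in encoder.modules():
        if isinstance(module, BATCHNORM_TYPES):
            module.train()              # running mean/var adapt to the new data
            module.requires_grad_(True)  # affine weight/bias get gradients

# Hypothetical toy encoder standing in for the real one
encoder = nn.Sequential(
    nn.Conv1d(64, 128, kernel_size=11, padding=5),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.2),
)
freeze_except_batchnorm(encoder)
trainable = [name for name, p in encoder.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']
```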
Okay, I'll try unfreezing all BatchNorm layers in the encoder.
I'll wait for your fine-tuning tutorial. Thanks!
I tried this (freezing the encoder and unfreezing the batch norm layers) and trained the model for 3 days. WER is still stuck at 1, and the loss has stalled.
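One common cause of a WER stuck at exactly 1 is the model emitting empty (or nearly empty) transcripts: with nothing hypothesised, every reference word counts as a deletion. A minimal word-level edit-distance WER in pure Python illustrates this; it is a sketch, not NeMo's own WER implementation.

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate = (substitutions + deletions + insertions) / reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j]: edit distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (0 cost on a match)
            )
        prev = curr
    return prev[-1] / max(len(ref), 1)

print(wer("привет как дела", ""))                 # 1.0 (empty output)
print(wer("привет как дела", "привет как дела"))  # 0.0
```

If the decoded predictions logged during validation are blank, the problem is usually upstream of the metric (labels, vocabulary, or training setup) rather than in WER itself.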
Maybe I am doing something wrong? Here is my code:
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer, amp_level='O1')
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    state_dict = torch.load('model_weights.ckpt', map_location='cpu')
    print('load model done')
    # Keep only encoder weights; key[8:] strips the 8-character 'encoder.' prefix
    encoder_state_dict = {key[8:]: value for key, value in state_dict.items() if key.split('.')[0] == 'encoder'}
    asr_model.encoder.load_state_dict(encoder_state_dict)
    del state_dict
    del encoder_state_dict
    torch.cuda.empty_cache()
    print('clear GPU memory')

    asr_model.encoder.freeze()
    True_class = type(asr_model.encoder.encoder[0].mconv[1])
    for i in asr_model.encoder.encoder:
        for j in i.mconv:
            if isinstance(j, True_class):
                # NB: this rebinds an attribute named requires_grad_; it does not call nn.Module.requires_grad_(True)
                j.requires_grad_ = True
    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
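The key[8:] slice works because 'encoder.' is exactly eight characters long. A slightly more explicit way to pull one submodule's weights out of a full-model state dict is to filter and strip by prefix. The helper below is hypothetical, and the key names and plain-integer values are illustrative stand-ins for real tensors.

```python
def extract_submodule_state(state_dict: dict, prefix: str = "encoder.") -> dict:
    """Keep only keys under `prefix` and strip it, e.g. 'encoder.x.weight' -> 'x.weight'."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Illustrative keys only; real checkpoints hold tensors, not ints
full = {
    "preprocessor.featurizer.window": 0,
    "encoder.encoder.0.mconv.0.conv.weight": 1,
    "encoder.encoder.0.mconv.1.weight": 2,
    "decoder.decoder_layers.0.weight": 3,
}
print(extract_submodule_state(full))
# {'encoder.0.mconv.0.conv.weight': 1, 'encoder.0.mconv.1.weight': 2}
```

In PyTorch, load_state_dict is strict by default, so a wrong prefix surfaces as missing or unexpected keys instead of silently loading nothing.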
With that True_class logic you are unfreezing the entire encoder again; you only need to unfreeze the batch norm within each module. I have a PR currently under review; see its section on unfreezing batch norm.
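A related detail in the loop above: j.requires_grad_ = True assigns a bool to an attribute on the module object instead of calling nn.Module.requires_grad_, so no parameter changes. A small sketch of the difference, using a standalone nn.BatchNorm1d as a stand-in for a layer inside the frozen encoder:

```python
import torch.nn as nn

# Pitfall: assigning to `requires_grad_` stores a bool on the module object
# and shadows the method; the parameters stay frozen.
broken = nn.BatchNorm1d(8)
broken.requires_grad_(False)   # simulate the frozen encoder
broken.requires_grad_ = True   # no effect on any parameter
print([p.requires_grad for p in broken.parameters()])  # [False, False]

# Fix: call the method, and switch the layer back to training mode so its
# running mean/variance keep updating on the new data.
fixed = nn.BatchNorm1d(8)
fixed.requires_grad_(False)
fixed.requires_grad_(True)
fixed.train()
print([p.requires_grad for p in fixed.parameters()])  # [True, True]
```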
I followed your Google Colab notebook, skipping a few of the dataset steps that build the vocabulary, because I provide that information in the config file.
These are the graphs I get:
I think something is wrong: training progresses far too slowly.
import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch
from nemo.collections.asr.models import EncDecCTCModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
import torch.nn as nn


def enable_bn_se(m):
    if type(m) == nn.BatchNorm1d:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)

    if 'SqueezeExcite' in type(m).__name__:
        m.train()
        for param in m.parameters():
            param.requires_grad_(True)


@hydra_runner(config_path="./", config_name="jasper_10x5dr.yaml")
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer, amp_level='O1')
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

    state_dict = torch.load('model_weights.ckpt', map_location='cpu')
    encoder_state_dict = {key[8:]: value for key, value in state_dict.items() if key.split('.')[0] == 'encoder'}
    asr_model.encoder.load_state_dict(encoder_state_dict)
    del state_dict
    del encoder_state_dict

    asr_model.encoder.freeze()
    asr_model.encoder.apply(enable_bn_se)
    logging.info("Model encoder has been frozen, and batch normalization has been unfrozen")

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        gpu = 1 if cfg.trainer.gpus != 0 else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
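Before launching a multi-day run, it can help to confirm which parameters will actually receive gradients. A minimal sketch that tallies trainable versus total parameters per module type; the two-layer toy model is a hypothetical stand-in for asr_model.encoder.

```python
from collections import Counter

import torch.nn as nn

def trainable_summary(model: nn.Module) -> dict:
    """Map module class name -> (trainable params, total params), counting each module's own params."""
    trainable, total = Counter(), Counter()
    for module in model.modules():
        kind = type(module).__name__
        for p in module.parameters(recurse=False):  # avoid double-counting children
            total[kind] += p.numel()
            if p.requires_grad:
                trainable[kind] += p.numel()
    return {kind: (trainable[kind], total[kind]) for kind in total}

# Hypothetical toy model: frozen conv, unfrozen batch norm
model = nn.Sequential(nn.Conv1d(4, 8, 3), nn.BatchNorm1d(8))
model.requires_grad_(False)
model[1].requires_grad_(True)
print(trainable_summary(model))
# {'Conv1d': (0, 104), 'BatchNorm1d': (16, 16)}
```

Running stats (running_mean, running_var) are buffers rather than parameters, so they do not appear in these counts; they update only while the layer is in training mode.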
I would advise you to go over the QuartzNet portion step by step and match it against your script. For example, I see you don't disable normalize_transcripts, don't load the pretrained checkpoint with from_pretrained, etc.
If it still doesn't train after all that, then either the data is insufficient or there is some other issue with the labeling.
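One quick labeling check in this spirit: make sure every character in the training transcripts belongs to the model's label set, since a character-based CTC model cannot emit characters outside it. A sketch in plain Python, assuming JSON-lines manifests with a text field; the manifest lines and label list below are hypothetical.

```python
import json
from collections import Counter

def out_of_vocab_chars(manifest_lines, labels):
    """Count transcript characters that are missing from the model's label set."""
    allowed = set(labels)
    missing = Counter()
    for line in manifest_lines:
        text = json.loads(line)["text"]
        missing.update(ch for ch in text if ch not in allowed)
    return missing

# Hypothetical lowercase Russian label set (space + 33 letters)
labels = [" "] + list("абвгдеёжзийклмнопрстуфхцчшщъыьэюя")
manifest = [
    '{"audio_filepath": "a.wav", "duration": 1.2, "text": "привет мир"}',
    '{"audio_filepath": "b.wav", "duration": 0.9, "text": "Ещё раз!"}',
]
print(out_of_vocab_chars(manifest, labels))  # Counter({'Е': 1, '!': 1})
```

A non-empty result (uppercase letters, punctuation, Latin look-alikes) points to transcripts that need normalizing, or to labels that need extending, before training.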
Hello!
I am trying to fine-tune an ASR model with a frozen encoder on Russian.
I took the optimizer parameters and learning rate from this paper https://arxiv.org/pdf/2005.04290.pdf and set warmup_steps to 2000, but during training the loss gets stuck and WER stays at 1. What should I do to solve this problem?
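For reference, warmup_steps controls how long the learning rate ramps up before the main schedule takes over. A minimal sketch of linear warmup followed by cosine annealing; this shows the general shape of such a schedule, the peak LR and step counts are hypothetical (not the paper's values), and NeMo's own scheduler may differ in details.

```python
import math

def lr_at(step: int, peak_lr: float, warmup_steps: int, max_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warmup from ~0 to peak_lr, then cosine decay from peak_lr to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(max_steps - warmup_steps, 1), 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Hypothetical numbers: peak LR 0.01, 2000 warmup steps, 20000 total steps
for step in (0, 999, 1999, 2000, 11000, 20000):
    print(step, lr_at(step, 0.01, 2000, 20000))
```

Printing the schedule this way is a cheap check that the learning rate actually reaches its intended peak well within the planned number of training steps.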
model_weights.ckpt was unpacked from the stt_en_jasper10x5dr.nemo model. Code for training: