NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Quartznet: freezing encoder #536

Closed: OndrejGl closed this issue 4 years ago

OndrejGl commented 4 years ago

Hi, I would like to re-use a trained Quartznet encoder and train only the decoder on new data. After the encoder is defined, I call encoder.freeze(). However, when I run training via torch.distributed.launch, I get: AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.

Here is the full traceback:

Traceback (most recent call last):
  File "/home/ogl/src/NeMo/examples/asr/quartznet.py", line 280, in <module>
    main()
  File "/home/ogl/src/NeMo/examples/asr/quartznet.py", line 275, in main
    synced_batchnorm_groupsize=args.synced_bn_groupsize,
  File "/home/ogl/src/NeMo/nemo/core/neural_factory.py", line 587, in train
    amp_max_loss_scale=amp_max_loss_scale,
  File "/home/ogl/src/NeMo/nemo/backends/pytorch/actions.py", line 1249, in train
    pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True
  File "/home/ogl/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 238, in __init__
    "DistributedDataParallel is not needed when a module "
AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.

What am I doing wrong? Thanks
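For context, the assertion is raised by PyTorch itself, not by NeMo: DistributedDataParallel checks that the module it wraps has at least one parameter with requires_grad=True, and a fully frozen encoder has none. Below is a minimal sketch of that condition, using plain PyTorch and a made-up stand-in module rather than actual NeMo code:

    import torch.nn as nn

    # Hypothetical stand-in for a frozen encoder: after freeze(), every
    # parameter ends up with requires_grad=False.
    encoder = nn.Sequential(nn.Conv1d(64, 256, kernel_size=33, padding=16), nn.ReLU())
    for p in encoder.parameters():
        p.requires_grad = False

    # This is the condition DistributedDataParallel asserts on before wrapping
    # a module; for a fully frozen module the list is empty, which produces
    # the AssertionError shown in the traceback above.
    trainable = [n for n, p in encoder.named_parameters() if p.requires_grad]
    print(trainable)  # [] -> DDP refuses to wrap this module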

bill-kalog commented 4 years ago

Hi,

I'm not sure about your problem, but are you sure that you still have any trainable parameters? Something like the following has worked fine for me so far. Here I freeze only part of the encoder graph, but the approach should be the same. Also, I don't load the decoder at all because I use a different vocabulary.

    import re

    import nemo.collections.asr as nemo_asr

    # quartz_params is the dict parsed from the Quartznet YAML config,
    # vocab is the label set for the new data.
    encoder = nemo_asr.JasperEncoder(
        feat_in=quartz_params["AudioToMelSpectrogramPreprocessor"]["features"], **quartz_params["JasperEncoder"],
    )
    encoder.restore_from("quartznet15x5/JasperEncoder-STEP-247400.pt")

    # Parameter name patterns that should stay trainable (everything else gets frozen).
    not_freeze_list = ["encoder.17.*"]

    # Freeze every parameter whose name does not match one of the patterns above.
    freeze_reg = '(?:%s)' % '|'.join(not_freeze_list)
    variables = [name for name, param in encoder.named_parameters() if not re.match(freeze_reg, name)]
    print(f"Will freeze:\n {variables}")
    encoder.freeze(variables)

    print(f"Trainable encoder params:\n {[name for name, param in encoder.named_parameters() if param.requires_grad]}")

    # Fresh decoder sized for the new vocabulary; its weights are not restored.
    decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=quartz_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab),
    )

Also, I have never used torch.distributed.launch, so maybe try running things without it first and see if you still have a similar problem.

okuchaiev commented 4 years ago

This should be fixed in the latest master. However, I strongly recommend that you do not freeze the encoder weights while fine-tuning (we are getting pretty good results when fine-tuning without freezing).
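For reference, a rough sketch of that recommendation, reusing the NeMo 0.x API from the snippet above (quartz_params and vocab again stand in for the parsed Quartznet YAML config and the new label set): restore the pretrained encoder, build a fresh decoder for the new vocabulary, and simply skip the freeze() call so every module still has trainable parameters and DDP can wrap it.

    import nemo.collections.asr as nemo_asr

    # Restore the pretrained Quartznet encoder but keep it trainable
    # (no encoder.freeze() call).
    encoder = nemo_asr.JasperEncoder(
        feat_in=quartz_params["AudioToMelSpectrogramPreprocessor"]["features"],
        **quartz_params["JasperEncoder"],
    )
    encoder.restore_from("quartznet15x5/JasperEncoder-STEP-247400.pt")

    # Fresh decoder sized for the new vocabulary; it is trained from scratch.
    decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=quartz_params["JasperEncoder"]["jasper"][-1]["filters"],
        num_classes=len(vocab),
    )
    # Training then proceeds as in examples/asr/quartznet.py; since both
    # modules have parameters with requires_grad=True, the DDP assertion
    # no longer applies.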