NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Possible problems with weight initialization in NeMo ASR [Question] #1462

Closed: rbracco closed this 3 years ago

rbracco commented 3 years ago

Describe your question

  1. Weight initialization in ConvASREncoder and ConvASRDecoder defaults to xavier_uniform, but the architectures use ReLU activations, which typically train best with Kaiming initialization. Why was Xavier initialization chosen?
  2. ConvASREncoder and ConvASRDecoder take an init_mode argument that delegates to nemo.collections.asr.parts.jasper.init_weights. Its results differ from PyTorch's default initialization (the weights have a different initial standard deviation), and in my experiments this leads to significantly worse training during transfer learning. Why aren't the PyTorch defaults used?
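
For reference, the standard deviation each scheme targets follows from the formulas documented for torch.nn.init. This is a minimal sketch, assuming a kernel_size=1 Conv1d (so fan_in = in_channels and fan_out = out_channels), Xavier with gain=1.0, and an illustrative 1024 → 256 shape; the helper names are hypothetical.

```python
import math


def xavier_uniform_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    # torch.nn.init.xavier_uniform_: std = gain * sqrt(2 / (fan_in + fan_out))
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


def kaiming_relu_std(fan_in: int) -> float:
    # torch.nn.init.kaiming_uniform_(w, nonlinearity="relu"): std = sqrt(2) / sqrt(fan_in)
    return math.sqrt(2.0) / math.sqrt(fan_in)


def pytorch_conv_default_std(fan_in: int) -> float:
    # Conv default: kaiming_uniform_ with a=sqrt(5), i.e. gain = sqrt(1/3)
    return math.sqrt(1.0 / 3.0) / math.sqrt(fan_in)


if __name__ == "__main__":
    fan_in, fan_out = 1024, 256  # illustrative kernel_size=1 Conv1d shape
    print(f"xavier_uniform : {xavier_uniform_std(fan_in, fan_out):.4f}")
    print(f"kaiming (relu) : {kaiming_relu_std(fan_in):.4f}")
    print(f"pytorch default: {pytorch_conv_default_std(fan_in):.4f}")
```

For this shape the three values come out at roughly 0.040, 0.044 and 0.018, so Xavier and ReLU-Kaiming are close to each other, while the PyTorch default is about 2.4 times smaller than either.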

Experimental Results

I tried transfer learning from QuartzNet to a dataset with a different vocabulary, using either 1 or 2 linear layers in my ASR decoder (the decoder code is included at the end of this post). I initialized the decoder weights three ways: NeMo's xavier_uniform, NeMo's kaiming_uniform, and the PyTorch defaults (Kaiming uniform is the default for 1D convs). I ran 12 trials for each: 6 trained for 1 epoch and 6 for 2 epochs, all with LR=1e-3. The mean loss after 1 and 2 epochs is shown below.

Mean loss by decoder initialization (per-initialization plots not reproduced):

- PyTorch default (Kaiming uniform): 1 epoch = 289.9, 2 epochs = 179.9
- NeMo kaiming_uniform: 1 epoch = 515.5, 2 epochs = 581.6
- NeMo xavier_uniform (NeMo default): 1 epoch = 566.8, 2 epochs = 408.0

Standard deviation of the decoder weights after initialization (plots for the 2-layer and 1-layer decoders not reproduced).
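
One way to reproduce a standard-deviation comparison like this without NeMo is to build fresh Conv1d layers and call torch.nn.init directly. This sketch assumes NeMo's xavier_uniform and kaiming_uniform modes reduce to xavier_uniform_(w, gain=1.0) and kaiming_uniform_(w, nonlinearity="relu"); the second call is quoted later in this thread, the first is an assumption. FEAT_IN = 1024 and the vocabulary size of 29 are illustrative, and make_layer and weight_std are hypothetical helpers.

```python
import torch

torch.manual_seed(0)
FEAT_IN, N_CLASSES = 1024, 29  # 29 is a hypothetical vocabulary size


def make_layer(mode):
    # A freshly constructed Conv1d already carries PyTorch's default init (mode=None).
    layer = torch.nn.Conv1d(FEAT_IN, N_CLASSES, kernel_size=1, bias=True)
    if mode == "xavier_uniform":
        torch.nn.init.xavier_uniform_(layer.weight, gain=1.0)
    elif mode == "kaiming_uniform":
        torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
    elif mode is not None:
        raise ValueError(f"unknown mode {mode!r}")
    return layer


def weight_std(mode) -> float:
    # Empirical standard deviation of the weight tensor right after initialization
    return make_layer(mode).weight.std().item()


if __name__ == "__main__":
    for mode in (None, "xavier_uniform", "kaiming_uniform"):
        print(f"{str(mode):>16}: std = {weight_std(mode):.4f}")
```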

Environment Details

Installed on Colab with pip: pip install nemo-toolkit[all]==1.0.0b1

Python 3.6.9, PyTorch 1.7, OS: Ubuntu 18.04.5 LTS

Additional Details

Definition of decoders

2 Layer Decoder

N_HIDDEN = 256
# Two pointwise (kernel_size=1) convs with a ReLU in between
self.decoder_layers = torch.nn.Sequential(
    torch.nn.Conv1d(self._feat_in, N_HIDDEN, kernel_size=1, bias=True),
    torch.nn.ReLU(),
    torch.nn.Conv1d(N_HIDDEN, self._num_classes, kernel_size=1, bias=True),
)

1 Layer Decoder

# A single pointwise conv maps encoder features directly to the vocabulary
self.decoder_layers = torch.nn.Sequential(
    torch.nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True),
)

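For experimenting outside NeMo, both decoders above can be wrapped in a small standalone module with an init_mode switch. TinyDecoder and its arguments are hypothetical; the init_mode values only mirror the options discussed in this thread (None meaning PyTorch's default init), and this is a sketch rather than NeMo's actual ConvASRDecoder code.

```python
from typing import Optional

import torch
from torch import nn


class TinyDecoder(nn.Module):
    """Hypothetical 1- or 2-layer pointwise-conv decoder head."""

    def __init__(self, feat_in: int, num_classes: int, n_hidden: Optional[int] = None,
                 init_mode: Optional[str] = "xavier_uniform"):
        super().__init__()
        if n_hidden is None:
            # 1-layer decoder: a single kernel_size=1 conv straight to the vocabulary
            self.decoder_layers = nn.Sequential(
                nn.Conv1d(feat_in, num_classes, kernel_size=1, bias=True))
        else:
            # 2-layer decoder: conv -> ReLU -> conv
            self.decoder_layers = nn.Sequential(
                nn.Conv1d(feat_in, n_hidden, kernel_size=1, bias=True),
                nn.ReLU(),
                nn.Conv1d(n_hidden, num_classes, kernel_size=1, bias=True),
            )
        if init_mode is not None:  # None keeps PyTorch's default initialization
            self.apply(lambda m: self._init(m, init_mode))

    @staticmethod
    def _init(m: nn.Module, mode: str) -> None:
        # Only conv weights are re-initialized; biases keep PyTorch's default
        if isinstance(m, nn.Conv1d):
            if mode == "xavier_uniform":
                nn.init.xavier_uniform_(m.weight, gain=1.0)
            elif mode == "kaiming_uniform":
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
            else:
                raise ValueError(f"unsupported init_mode {mode!r}")

    def forward(self, encoder_output: torch.Tensor) -> torch.Tensor:
        # (batch, feat_in, time) -> (batch, time, num_classes) log-probabilities
        return self.decoder_layers(encoder_output).transpose(1, 2).log_softmax(dim=-1)
```

For example, TinyDecoder(1024, 29, n_hidden=256, init_mode=None) applied to a (2, 1024, 50) input returns log-probabilities of shape (2, 50, 29).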
titu1994 commented 3 years ago

This is a good observation. I'll run some full-scale experiments on this and respond in a day or two, if you don't mind. To clarify: by "pytorch default" you mean that you don't apply the weight initialization method at all and just use the weights PyTorch initializes for convs by default, yes?

rbracco commented 3 years ago

Thanks for taking the time to look into it. Yes, that's what I mean by pytorch default. It can be done either by commenting out self.apply(lambda x: init_weights(x, mode=init_mode)) in the decoder, or by replacing the layer manually with quartznet.decoder.decoder_layers[0] = nn.Conv1d(1024, <N_CLASSES>, kernel_size=1, stride=1).

The default init for a 1D conv in PyTorch is Kaiming uniform, yet I get a different standard deviation with the PyTorch version than with the NeMo version. I didn't dig into why that might be.

titu1994 commented 3 years ago

I think I have an idea of why applying the default differs from applying kaiming_uniform.

This is PyTorch's default reset_parameters() implementation for all ConvNd layers:

    # Method of torch.nn.modules.conv._ConvNd; `init` is torch.nn.init
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
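
A quick numerical check of what this default produces: with a = sqrt(5), both the weights and the bias end up drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), so the weight standard deviation is 1/sqrt(3 * fan_in). The sketch below inspects a freshly constructed layer; the 1024 → 29 shape is illustrative.

```python
import math

import torch

torch.manual_seed(0)
FAN_IN, OUT = 1024, 29  # illustrative kernel_size=1 Conv1d
layer = torch.nn.Conv1d(FAN_IN, OUT, kernel_size=1)  # reset_parameters() runs in __init__

bound = 1.0 / math.sqrt(FAN_IN)
# Under the default init, weight and bias both lie in [-bound, bound]
print("max |weight| / bound:", layer.weight.abs().max().item() / bound)
print("max |bias|   / bound:", layer.bias.abs().max().item() / bound)
# Uniform on [-bound, bound] has std bound / sqrt(3), so this should be close to 1
print("weight std * sqrt(3 * fan_in):", layer.weight.std().item() * math.sqrt(3 * FAN_IN))
```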

Note the a=math.sqrt(5) argument, combined with kaiming_uniform_'s default nonlinearity of 'leaky_relu'.

For the kaiming_uniform mode in NeMo, the gain is computed for the ReLU activation, as expected. Digging into how that gain is computed, NeMo's init_weights calls

nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")

which keeps the default a=0 but uses a different nonlinearity. Here is the relevant part of kaiming_uniform_:

def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    ...
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

The difference lies in the gain computation:

def calculate_gain(nonlinearity, param=None):
    ...
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

Now, let's manually compute the output of calculate_gain for the PyTorch default and for NeMo's kaiming_uniform init_mode:

default = calculate_gain('leaky_relu', param=math.sqrt(5)) = sqrt(2.0 / (1 + 5)) = sqrt(1/3) ≈ 0.577
kaiming_uniform = calculate_gain('relu', param=0) = sqrt(2.0) ≈ 1.414

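This is why the default does not match kaiming_uniform: both compute std = gain / sqrt(fan_in), so for the same fan_in the default's weight bounds and standard deviation are sqrt(6) ≈ 2.45 times smaller than with NeMo's kaiming_uniform.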
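
These gains can be checked directly with torch.nn.init.calculate_gain. The sketch also turns each gain into the uniform bound that kaiming_uniform_ would use, for an illustrative fan_in of 1024.

```python
import math

from torch.nn.init import calculate_gain

default_gain = calculate_gain("leaky_relu", math.sqrt(5))  # PyTorch conv default (a = sqrt(5))
relu_gain = calculate_gain("relu")                           # NeMo's kaiming_uniform mode

fan_in = 1024  # illustrative
for name, gain in (("default", default_gain), ("kaiming_uniform", relu_gain)):
    # Same formula as kaiming_uniform_: bound = sqrt(3) * gain / sqrt(fan)
    bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
    print(f"{name:>15}: gain = {gain:.4f}, bound = {bound:.4f}")

print("ratio of gains:", relu_gain / default_gain)  # sqrt(6), about 2.449
```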

rbracco commented 3 years ago

Very interesting! Thanks for looking into it. I am going to dive deeper and try to figure out why my model is performing better using pytorch defaults (it could still be chance) and I'll make sure to report back.

Also, if the decoder is a single layer without an activation function (other than the softmax, I guess) and the encoder is frozen, why do we need Kaiming initialization at all? Shouldn't we just initialize to a mean of 0 and unit standard deviation?
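
One way to gauge the unit-standard-deviation question: for a single kernel_size=1 conv with no activation, each output logit sums fan_in products, so for independent zero-mean inputs its standard deviation is about sqrt(fan_in) times the weight std times the input scale. The sketch below uses unit-variance random features as a hypothetical stand-in for frozen encoder outputs (real QuartzNet features will have different statistics) and compares the logit spread under N(0, 1) weights and under PyTorch's default init.

```python
import torch

torch.manual_seed(0)
FEAT_IN, N_CLASSES = 1024, 29            # illustrative sizes
features = torch.randn(8, FEAT_IN, 100)  # stand-in for frozen encoder output, unit variance


def logit_std(layer: torch.nn.Conv1d) -> float:
    # Spread of the pre-softmax logits over batch, classes and time
    with torch.no_grad():
        return layer(features).std().item()


default_layer = torch.nn.Conv1d(FEAT_IN, N_CLASSES, kernel_size=1)  # PyTorch default init

unit_layer = torch.nn.Conv1d(FEAT_IN, N_CLASSES, kernel_size=1)
torch.nn.init.normal_(unit_layer.weight, mean=0.0, std=1.0)  # mean 0, unit std
torch.nn.init.zeros_(unit_layer.bias)

print("default init logit std:", logit_std(default_layer))  # roughly sqrt(1/3)
print("N(0, 1) init logit std:", logit_std(unit_layer))     # roughly sqrt(FEAT_IN) = 32
```

With unit-std weights the initial logits land in the tens, which makes the initial softmax very peaked; scaling the weight std by 1/sqrt(fan_in), as all three schemes in this thread do, keeps the logits near the input scale instead.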

titu1994 commented 3 years ago

Hey @rbracco, I have some preliminary results (at least for from-scratch training).

Loss curves (plot not reproduced)

Though I haven't plotted them here, the WERs (train, dev, test) exactly match the shape of these curves. While this is only a single-point observation, xavier for both encoder and decoder is the best bet for from-scratch training.

This doesn't invalidate your observation that the loss decreases faster during finetuning. For the time being, I think we can allow init_mode=None, which would keep PyTorch's default initialization, since that is a potential use case. However, we'd need extensive experimentation to show that it's worth doing for the decoder during finetuning (which I currently can't do).

titu1994 commented 3 years ago

Related PR https://github.com/NVIDIA/NeMo/pull/1472

rbracco commented 3 years ago

Awesome work @titu1994! Good to see that it isn't a problem for training from scratch. I'm not sure, but maybe batchnorm lessens the importance of init, since all the runs seem to end up in the same place. I will keep experimenting with transfer learning and report back. I just switched my training from English to Spanish with a completely different dataset and vocabulary, so I will try several inits on the new set and see whether the results resemble what I saw before or whether it was just a fluke. I should be able to report back early next week.
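
The batchnorm intuition can be illustrated in isolation: in training mode, BatchNorm1d normalizes each channel with the batch statistics, so multiplying the preceding conv's weights and bias by a positive constant leaves the normalized output essentially unchanged (up to BatchNorm's eps). This toy conv → BatchNorm sketch, with arbitrary illustrative layer sizes, only shows that scale invariance at a single forward pass; it is not a claim about how NeMo's Jasper/QuartzNet blocks actually train.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
conv = torch.nn.Conv1d(64, 32, kernel_size=3, padding=1)
bn = torch.nn.BatchNorm1d(32)  # training mode: normalizes with batch statistics
x = torch.randn(16, 64, 50)


def block_output(scale: float) -> torch.Tensor:
    with torch.no_grad():
        # Scale the conv weights and bias by the same positive factor, then normalize
        y = F.conv1d(x, conv.weight * scale, conv.bias * scale, padding=1)
        return bn(y)


out_1x, out_10x = block_output(1.0), block_output(10.0)
print("max abs difference:", (out_1x - out_10x).abs().max().item())
```

Without the BatchNorm layer, the same 10x scaling would multiply the block's output by 10.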

titu1994 commented 3 years ago

PR https://github.com/NVIDIA/NeMo/pull/1472 is merged, so you can now simply pass init_mode=None to get PyTorch's default initialization.

rbracco commented 3 years ago

Thank you! Just wanted to report back that on the Spanish data my model failed to converge, and after switching the init to the PyTorch default it converged. I haven't had time to run full experiments, but I will probably do a write-up on transfer learning with NeMo at some point, as I'm discovering lots of things that lead to faster convergence.