joe-siyuan-qiao / WeightStandardization

Standardizing weights to accelerate micro-batch training

It just explodes!!! #23

Open VCasecnikovs opened 4 years ago

VCasecnikovs commented 4 years ago

Hello, I've been testing WS on my dataset and my network, and I have read about the std issue. But even after using std = (torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5), I find that the input still explodes. When I use a plain Conv2d block this problem does not occur. So my question is: is there a way to work around this?
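For reference, a minimal sketch of the kind of weight-standardized convolution being described (the class name StdConv2d and the subclassing of nn.Conv2d are assumptions on my part; the eps terms are the ones quoted above):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class StdConv2d(nn.Conv2d):
        def forward(self, x):
            w = self.weight
            # per-output-filter mean and std over (in_channels, kH, kW)
            mean = w.view(w.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
            std = (torch.sqrt(torch.var(w.view(w.size(0), -1), dim=1) + 1e-12)
                   .view(-1, 1, 1, 1) + 1e-5)
            w = (w - mean) / std
            return F.conv2d(x, w, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)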

MohitLamba94 commented 3 years ago

Adding an affine transformation right after standardization, as in https://github.com/open-mmlab/mmcv/blob/d5cbf7eed1269095bfba1a07913efbbc99d2d10b/mmcv/cnn/bricks/conv_ws.py#L54, might help. I have not tried it myself, but it was originally used by Google Research to avoid NaNs and might work for you as well. Let me know if it helps.
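A rough sketch of that idea (my paraphrase of the linked mmcv conv_ws.py, not a verbatim copy): a learnable per-output-channel scale and shift applied to the weight right after standardization.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ConvAWS2d(nn.Conv2d):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # learnable affine parameters, one scale/shift per output filter
            self.weight_gamma = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
            self.weight_beta = nn.Parameter(torch.zeros(self.out_channels, 1, 1, 1))

        def forward(self, x):
            w = self.weight
            mean = w.view(w.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
            std = w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
            w = (w - mean) / std
            # affine transform right after standardization (helps avoid NaNs)
            w = self.weight_gamma * w + self.weight_beta
            return F.conv2d(x, w, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)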

csvance commented 10 months ago

It explodes even in the forward pass because the activation values tend to be much larger when using weight standardization: the weights are normalized to std = 1 instead of something like gain / sqrt(fan).

Try the following for the forward pass, changing the gain if you use an activation function other than ReLU:

    def forward(self, x):
        weight = F.batch_norm(
            self.weight.reshape(1, self.out_channels, -1), None, None,
            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)

        gain = nn.init.calculate_gain('relu')
        fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
        std = gain / fan**0.5
        weight = std*weight
        x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x

This way the weights are standardized similarly to He normal initialization. You still get all the benefits of weight standardization without any activation-scale explosion.
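For completeness, here is roughly what the surrounding module might look like (a sketch only: the name WSConv2d, the constructor, and the eps attribute are assumptions; the forward pass is the one above):

    import torch.nn as nn
    import torch.nn.functional as F

    class WSConv2d(nn.Conv2d):
        """Weight-standardized conv rescaled to match He (fan_out) initialization."""
        def __init__(self, *args, eps=1e-5, **kwargs):
            super().__init__(*args, **kwargs)
            self.eps = eps

        def forward(self, x):
            # standardize each output filter to zero mean, unit variance
            weight = F.batch_norm(
                self.weight.reshape(1, self.out_channels, -1), None, None,
                training=True, momentum=0., eps=self.eps).reshape_as(self.weight)

            # rescale so each filter has std = gain / sqrt(fan_out), as in He init
            gain = nn.init.calculate_gain('relu')
            fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
            weight = (gain / fan ** 0.5) * weight

            return F.conv2d(x, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

A module like this can then be swapped in wherever nn.Conv2d is used, typically alongside GroupNorm for micro-batch training.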

csvance commented 10 months ago

Just to summarize my overall thoughts on using this method:

  1. People are used to training networks with He initialization, but with weight standardization the activation scales are considerably higher than with He normal init. This can lead to NaN values during the forward pass, especially when using mixed precision. In practice it can be solved by scaling the standardized weights to match a popular initialization scheme such as He normal.
  2. A common practice when initializing ResNets is to set the weights of the last layer in each residual block to zero to avoid exploding gradients at initialization. But with weight standardization this is not really practical: even if you set those weights to zero, after the first step the standardized weights simply point in the direction of the first gradient. This means every block is fully ungated at initialization, so to speak, which causes exploding gradients and training instability, even if convergence appears faster at first. In my experience, aggressive gradient clipping is needed to keep the weight distributions from distorting, especially with an adaptive optimizer like Adam (see the sketch after this list).
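
A minimal sketch of the kind of aggressive gradient clipping mentioned in point 2 (the model, data, and max_norm value below are placeholders just to make the snippet runnable, not recommendations from this thread):

    import torch
    import torch.nn as nn

    # placeholder model and synthetic data
    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(8 * 8 * 8, 10))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(10):
        x = torch.randn(4, 3, 8, 8)
        y = torch.randint(0, 10, (4,))
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        # clip the global gradient norm before stepping to keep the
        # standardized weight distributions from distorting
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()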