VCasecnikovs opened this issue 4 years ago
Adding an affine transformation right after standardization, as in https://github.com/open-mmlab/mmcv/blob/d5cbf7eed1269095bfba1a07913efbbc99d2d10b/mmcv/cnn/bricks/conv_ws.py#L54, might help. I have not tried it myself, but it was originally used by Google Research to avoid NaNs and might work for you as well. Let me know if it helps.
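For reference, here is a rough sketch of that idea, not the exact mmcv code: standardize each filter, then apply a learnable per-output-channel scale and bias. The class name, the weight_gamma / weight_beta parameters, and the eps default are my own choices for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvWSAffine2d(nn.Conv2d):
    def __init__(self, *args, eps=1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        # learnable per-output-channel scale and bias applied after standardization
        self.weight_gamma = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.weight_beta = nn.Parameter(torch.zeros(self.out_channels, 1, 1, 1))

    def forward(self, x):
        # standardize each filter to zero mean / unit std
        mean = self.weight.mean(dim=(1, 2, 3), keepdim=True)
        std = self.weight.std(dim=(1, 2, 3), keepdim=True) + self.eps
        weight = self.weight_gamma * (self.weight - mean) / std + self.weight_beta
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)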
It explodes even in the forward pass because the activation values tend to be much larger when using weight standardization. This is because the weights are normalized to std = 1 instead of something like gain / sqrt(fan).
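A quick numerical illustration of the mismatch, assuming a hypothetical 3x3 conv with 64 input and 64 output channels: weight standardization leaves each filter at std = 1, while He (Kaiming) init would use gain / sqrt(fan_out) = sqrt(2) / sqrt(64 * 3 * 3) ≈ 0.059, so the standardized weights are roughly 17x too large.

import torch

weight = torch.randn(64, 64, 3, 3)
ws = (weight - weight.mean(dim=(1, 2, 3), keepdim=True)) / weight.std(dim=(1, 2, 3), keepdim=True)
print(ws.std().item())              # ~1.0 after standardization
print((2.0 / (64 * 3 * 3)) ** 0.5)  # ~0.059, the He-init scale for fan_out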
Try the following forward pass, changing the gain if you use an activation function other than ReLU:
# assumes: import torch.nn as nn, torch.nn.functional as F
def forward(self, x):
    # standardize each filter (zero mean, unit std) by running batch_norm over the filter dimension
    weight = F.batch_norm(
        self.weight.reshape(1, self.out_channels, -1), None, None,
        training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
    # rescale to the He-normal magnitude: gain / sqrt(fan_out)
    gain = nn.init.calculate_gain('relu')
    fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
    std = gain / fan ** 0.5
    weight = std * weight
    x = F.conv2d(x, weight, self.bias, self.stride, self.padding,
                 self.dilation, self.groups)
    return x
This way the weights are standardized similarly to He normal initialization. You still get all the benefits of weight standardization without any activation-scale explosion.
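If it helps, here is a self-contained sketch that wraps the forward above into an nn.Conv2d subclass, together with a quick sanity check on the activation scale. The class name ScaledWSConv2d and the eps default are my own choices.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledWSConv2d(nn.Conv2d):
    def __init__(self, *args, eps=1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x):
        # standardize each filter, then rescale to the He-normal magnitude
        weight = F.batch_norm(
            self.weight.reshape(1, self.out_channels, -1), None, None,
            training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
        gain = nn.init.calculate_gain('relu')
        fan = nn.init._calculate_correct_fan(self.weight, 'fan_out')
        weight = weight * gain / fan ** 0.5
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

x = torch.randn(8, 64, 32, 32)
conv = ScaledWSConv2d(64, 64, 3, padding=1)
print(conv(x).std().item())  # stays on the order of 1 instead of exploding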
Just to summarize my overall thoughts on using this method:
Hello, I've been testing WS on my dataset and my network. I have read about the std fix, but even after using std = (torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5), I found that there is still a problem with the input exploding. When I use a basic Conv2d block this problem does not exist. So my question is: is it possible to fix this somehow?