Open adrienchaton opened 5 years ago
Hi, I tried to modify the original code to fit the case you mentioned:
class SwitchNorm1d(nn.Module):
    """Switchable normalization for 3D inputs laid out as (B, C, L).

    The statistics used for normalization are a learned, softmax-weighted
    blend of three estimates:

    * instance statistics -- per sample and channel, over the last dim,
    * layer statistics    -- per sample and position, over the channel dim,
    * batch statistics    -- per channel, over the batch and the last dim.

    During training the batch statistics are computed from the current
    input and accumulated into the ``running_mean`` / ``running_var``
    buffers; in eval mode the buffers are used instead.

    Variances are computed with ``Tensor.var``'s default estimator, so the
    channel and length dimensions should each hold more than one element.

    Args:
        num_features: number of channels C of the expected input.
        eps: value added to the blended variance before the square root.
        momentum: decay factor applied to the running buffers when
            ``using_moving_average`` is true
            (``running = momentum * running + (1 - momentum) * batch``).
        using_moving_average: if true, keep an exponential moving average
            of the batch statistics; otherwise the batch mean and the
            batch second moment (``mean ** 2 + var``) are summed into the
            buffers on every training step.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.997,
                 using_moving_average=True):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average

        # Per-channel affine transform, broadcast over batch and length.
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))

        # Blending logits, ordered as [instance, layer, batch].
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))

        self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
        self.register_buffer('running_var', torch.zeros(1, num_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the running buffers and reset the affine transform to identity."""
        self.running_mean.zero_()
        self.running_var.zero_()
        # In-place edits of parameters must not be recorded by autograd.
        with torch.no_grad():
            self.weight.fill_(1)
            self.bias.zero_()

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has exactly three dimensions."""
        if input.dim() != 3:
            raise ValueError(f'expected 3D input (got {input.dim()}D input)')

    def _update_running_stats(self, mean_bn, var_bn):
        """Fold the current (detached) batch statistics into the buffers."""
        with torch.no_grad():
            if self.using_moving_average:
                keep = self.momentum
                self.running_mean.mul_(keep)
                self.running_mean.add_((1 - keep) * mean_bn)
                self.running_var.mul_(keep)
                self.running_var.add_((1 - keep) * var_bn)
            else:
                # Plain accumulation: sums of the mean and of the second moment.
                self.running_mean.add_(mean_bn)
                self.running_var.add_(mean_bn ** 2 + var_bn)

    def forward(self, x):  # (B, C, L)
        """Normalize ``x`` with the blended statistics and apply the affine map.

        Args:
            x: tensor of shape (B, C, L).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        self._check_input_dim(x)

        # Layer statistics: reduce over channels.
        mean_ln = x.mean(1, keepdim=True)   # (B, 1, L)
        var_ln = x.var(1, keepdim=True)
        # Instance statistics: reduce over the last dimension.
        mean_in = x.mean(-1, keepdim=True)  # (B, C, 1)
        var_in = x.var(-1, keepdim=True)

        if self.training:
            # Batch statistics are derived from the instance ones:
            # E[x^2] averaged over the batch, minus the squared batch mean.
            second_moment = var_in + mean_in ** 2
            mean_bn = mean_in.mean(0, keepdim=True)  # (1, C, 1)
            var_bn = second_moment.mean(0, keepdim=True) - mean_bn ** 2
            self._update_running_stats(mean_bn.detach(), var_bn.detach())
        else:
            mean_bn = self.running_mean
            var_bn = self.running_var

        mean_weight = torch.softmax(self.mean_weight, dim=0)
        var_weight = torch.softmax(self.var_weight, dim=0)

        mean = (mean_weight[0] * mean_in
                + mean_weight[1] * mean_ln
                + mean_weight[2] * mean_bn)
        var = (var_weight[0] * var_in
               + var_weight[1] * var_ln
               + var_weight[2] * var_bn)

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias
For temporal data, LN is computed slightly differently from how it is computed for images (SwitchNorm2d). It should work as desired :)
Hello, thank you for providing the code implementation for your paper.
I am interested in trying your normalization in my current experiment, which works on raw waveforms and audio "style". It is thus of prime interest to adaptively modulate different feature normalizations, and I hope your proposal will work well for my purposes.
However, when I read your Switchable-Normalization/devkit/ops/switchable_norm.py, the 1d normalization only applies to 2D tensors and the 2d normalization only applies to 4D tensors, whereas the PyTorch implementations of BatchNorm1d and InstanceNorm1d apply to both 2D and 3D tensors.
If possible, could you please explain how I should apply your SwitchNorm1d to 3D tensors, for instance the output of conv1d?
thank you !