switchablenorms / Switchable-Normalization

Code for Switchable Normalization from "Differentiable Learning-to-Normalize via Switchable Normalization", https://arxiv.org/abs/1806.10779

Switch Norm 1d for 3D tensors #19

Open adrienchaton opened 5 years ago

adrienchaton commented 5 years ago

Hello, thank you for providing the code implementation of your paper.

I am interested in trying your normalization in my current experiment, which works on raw waveforms and audio "style". It is thus of prime interest to adaptively modulate different feature normalizations, and I hope your proposal will work well for my use case.

However, when I read your Switchable-Normalization/devkit/ops/switchable_norm.py, the 1d normalization only applies to 2D tensors and the 2d normalization only applies to 4D tensors, whereas the PyTorch implementations of BatchNorm1d and InstanceNorm1d apply to both 2D and 3D tensors.

If possible, how should I apply your SwitchNorm1d to 3D tensors, such as the output of a conv1d? The snippet below illustrates the mismatch.
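
For reference, a minimal snippet showing the difference (the shapes are just an example):

import torch
import torch.nn as nn

x = torch.randn(8, 64, 100)   # (B, C, L), e.g. the output of a Conv1d

nn.BatchNorm1d(64)(x)         # works: PyTorch's 1d norms accept 3D input
nn.InstanceNorm1d(64)(x)      # works as well
# SwitchNorm1d(64)(x)         # the repo's version rejects 3D input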

thank you !

zqiao11 commented 1 year ago

Hi, I modified the original code to fit the case you mentioned:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwitchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super(SwitchNorm1d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        # per-channel affine parameters, broadcast over (B, C, L)
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))
        # learnable logits for mixing the (IN, LN, BN) statistics
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        # running statistics for the BN branch, used at eval time
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
        self.register_buffer('running_var', torch.zeros(1, num_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):  # (B, C, L)
        self._check_input_dim(x)
        mean_ln = x.mean(1, keepdim=True)  # (B, 1, L)
        var_ln = x.var(1, keepdim=True)

        mean_in = x.mean(-1, keepdim=True)  # (B, C, 1)
        var_in = x.var(-1, keepdim=True)
        temp = var_in + mean_in ** 2  # E[x^2] per (sample, channel)

        if self.training:
            # BN statistics are obtained by averaging the IN statistics over the batch
            mean_bn = mean_in.mean(0, keepdim=True)  # (1, C, 1)
            var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2  # Var[x] = E[x^2] - E[x]^2

            if self.using_moving_average:
                # exponential moving average of the BN statistics
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.data)
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.data)
            else:
                # accumulate batch statistics instead of averaging
                self.running_mean.add_(mean_bn.data)
                self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
        else:
            # at eval time, use the running statistics for the BN branch
            # (torch.autograd.Variable is deprecated; the buffers can be used directly)
            mean_bn = self.running_mean
            var_bn = self.running_var

        mean_weight = F.softmax(self.mean_weight, dim=0)
        var_weight = F.softmax(self.var_weight, dim=0)

        # mix the IN / LN / BN statistics with the learned softmax weights
        mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        # normalize with the mixed statistics, then apply the affine transform
        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias

For temporal data, LN is computed slightly differently than for images (SwitchNorm2d): here the LN statistics are taken over the channel dimension at each time step. It should work as desired :)
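
A quick smoke test of the class above (the shapes are just an example):

import torch

sn = SwitchNorm1d(num_features=64)
x = torch.randn(8, 64, 100)  # (B, C, L), e.g. the output of a Conv1d

sn.train()
y = sn(x)        # normalized with batch statistics, shape (8, 64, 100)

sn.eval()
y_eval = sn(x)   # the BN branch now uses the running statistics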