rll/deepul

ActNorm implementation missing division by `std` on the shift parameter #6

Open · hemildesai opened this issue 4 years ago

hemildesai commented 4 years ago

Hi,

Thanks for making the video lectures and homework public. I'm really enjoying the course so far. I was going through homework 2 and wanted to compare my implementation with the solutions. In the hw2 solutions, I found the following implementation of ActNorm:

import torch
import torch.nn as nn

class ActNorm(nn.Module):
    def __init__(self, n_channels):
        super(ActNorm, self).__init__()
        self.log_scale = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.shift = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.n_channels = n_channels
        self.initialized = False

    def forward(self, x, reverse=False):
        if reverse:
            # Inverse pass: undo the affine transform.
            return (x - self.shift) * torch.exp(-self.log_scale), self.log_scale
        else:
            if not self.initialized:
                # Data-dependent initialization on the first batch:
                # per-channel shift <- -mean and log_scale <- -log(std).
                self.shift.data = -torch.mean(x, dim=[0, 2, 3], keepdim=True)
                self.log_scale.data = -torch.log(
                    torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1)
                    .reshape(1, self.n_channels, 1, 1))
                self.initialized = True
                result = x * torch.exp(self.log_scale) + self.shift  # note: unused
            return x * torch.exp(self.log_scale) + self.shift, self.log_scale
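
To make the problem concrete, here is a quick check of my own (not from the homework code) that runs the module above on non-normalized random data and looks at the per-channel statistics right after the data-dependent initialization:

torch.manual_seed(0)
x = 5.0 + 3.0 * torch.randn(64, 8, 16, 16)  # per-channel mean ~5, std ~3

actnorm = ActNorm(n_channels=8)
y, _ = actnorm(x)  # first call triggers the data-dependent init

print(y.mean(dim=[0, 2, 3]))  # roughly -3.3 per channel (= mean/std - mean), not ~0
print(y.permute(1, 0, 2, 3).reshape(8, -1).std(dim=1))  # ~1 per channel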

I think the shift also needs to be scaled by `exp(log_scale)`, i.e. divided by the standard deviation, for the activations to actually be normalized. With the current initialization the forward pass gives `y = x * exp(log_scale) + shift = x / std - mean`, whose per-channel mean is `mean * (1 / std - 1)`, which is not zero in general. The fix (with the `log_scale` assignment moved above the `shift` assignment, since the corrected shift depends on it) would be:

self.shift.data = -(torch.mean(x, dim=[0, 2, 3], keepdim=True) * torch.exp(self.log_scale))
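
To double-check the fix, here is a throwaway subclass (the name `ActNormFixed` is mine, just for illustration) with the initialization reordered so `log_scale` is computed before the shift that uses it:

class ActNormFixed(ActNorm):
    def forward(self, x, reverse=False):
        if not reverse and not self.initialized:
            # Compute log_scale first, then scale the shift by exp(log_scale) = 1/std.
            self.log_scale.data = -torch.log(
                torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1)
                .reshape(1, self.n_channels, 1, 1))
            self.shift.data = -(torch.mean(x, dim=[0, 2, 3], keepdim=True)
                                * torch.exp(self.log_scale))
            self.initialized = True
        return super().forward(x, reverse)

torch.manual_seed(0)
x = 5.0 + 3.0 * torch.randn(64, 8, 16, 16)
y, _ = ActNormFixed(n_channels=8)(x)
print(y.mean(dim=[0, 2, 3]))  # ~0 per channel
print(y.permute(1, 0, 2, 3).reshape(8, -1).std(dim=1))  # ~1 per channel

With this change the first batch comes out with approximately zero mean and unit variance per channel, which matches how the actnorm initialization is described in the Glow paper.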

Let me know if I'm missing something.