NVlabs / FUNIT

Translate images to unseen domains in the test time with few example images.

Why batch_norm is used inside AdaptiveInstanceNorm2d? #31

Open leitro opened 4 years ago

leitro commented 4 years ago

Hi! I have a question about the code in blocks.py (L188-L192), shown below:

class AdaptiveInstanceNorm2d(nn.Module):
    ...
    def forward(self, x):
        ...  # b, c, running_mean and running_var are set up earlier in forward
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

This is the definition of adaptive instance normalization. It looks like you reshape a batch of images into a single "bigger" image with b * c channels, apply batch normalization to it, and finally reshape the result back to (batch, channel, height, width). But whether or not the input is reshaped into a single-sample batch, it seems to me that the features of each channel are still normalized across the whole batch. So I am wondering how this can be an instance normalization.

I believe the code is perfectly correct, but please explain the tricks that were used here, thanks in advance!

iperov commented 4 years ago

this is why I don't like pytorch. :D

pomelyu commented 4 years ago

There is a slight difference between instance norm and adaptive instance norm.

In instance norm, each channel of each sample is normalized over its own spatial dimensions (H, W), and the same weight and bias are applied to every sample, so both have shape (num_channels). In adaptive instance norm, however, the weight and bias should have shape (batch_size * num_channels), since each sample gets a different modulation from its corresponding latent.

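That is why the code reshapes x to (1, batch_size * num_channels, H, W) and then uses F.batch_norm instead of F.instance_norm to apply the modulation to each sample and each channel. After the reshape, the batch dimension is 1, so every (sample, channel) pair becomes its own channel. With training set to True, F.batch_norm computes the mean and variance of each channel over (N, H, W), which here is just (H, W) of a single sample. Nothing is mixed across the original batch, so the result is an instance normalization with a per-sample weight and bias.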
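
The equivalence described above can be checked numerically. The following is a minimal sketch, not code from the FUNIT repository: the tensor sizes, the random `weight` and `bias`, and all variable names are made up for illustration, and the running statistics are passed as `None` for simplicity (the original code passes real buffers). It compares the reshape-then-`F.batch_norm` trick against a manual per-sample, per-channel normalization and against `F.instance_norm` followed by the same per-sample affine transform.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 4, 3, 8, 8   # illustrative sizes, not taken from FUNIT
eps = 1e-5
x = torch.randn(b, c, h, w)

# Hypothetical AdaIN parameters: one scale and one shift per (sample, channel).
weight = torch.randn(b * c)
bias = torch.randn(b * c)

# The trick from the thread: fold the batch into the channel axis, then run
# batch_norm in training mode so it uses the statistics of the current input.
x_reshaped = x.contiguous().view(1, b * c, h, w)
out_trick = F.batch_norm(
    x_reshaped, None, None, weight, bias,
    True, 0.1, eps).view(b, c, h, w)

# Manual reference: normalize every (sample, channel) slice over (H, W) only,
# using the biased variance, then apply the per-sample affine parameters.
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)
out_ref = x_hat * weight.view(b, c, 1, 1) + bias.view(b, c, 1, 1)

# Same result via F.instance_norm without affine, modulated afterwards.
out_in = F.instance_norm(x, eps=eps) * weight.view(b, c, 1, 1) + bias.view(b, c, 1, 1)

# For contrast: batch norm on the un-reshaped input pools statistics across
# the batch, so it does not match the per-sample normalization.
out_bn = F.batch_norm(x, None, None, None, None, True, 0.1, eps)

print(torch.allclose(out_trick, out_ref, atol=1e-4))
print(torch.allclose(out_trick, out_in, atol=1e-4))
print(torch.allclose(out_bn, x_hat, atol=1e-4))
```

The last comparison addresses the original question: batch statistics only mix samples when the batch dimension is larger than 1, and the reshape is what removes that mixing.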