jeonsworld / ViT-pytorch

PyTorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)
MIT License

Why the kernel is normalized in StdConv2d? #21

Open xychenunc opened 3 years ago

xychenunc commented 3 years ago

I noticed that you used

import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):

    def forward(self, x):
        # Standardize the weights per output filter before convolving.
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

Why is 'w' normalized here? Is there any special consideration behind implementing it this way? Thanks

jeonsworld commented 3 years ago

For CNNs, weight standardization is suggested in Big Transfer (BiT): General Visual Representation Learning. See Section 4.3 of that paper.
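To illustrate what the standardization does, here is a minimal sketch (not from the repo, just a self-contained check) showing that after the transform in `StdConv2d.forward`, the weights of each output filter have mean ≈ 0 and variance ≈ 1:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    # Same standardization as the snippet in the question: each output
    # filter's weights are shifted to zero mean and scaled to unit
    # variance before the convolution is applied.
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

conv = StdConv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)
y = conv(x)  # shape (1, 8, 16, 16), same as a plain Conv2d

# Apply the same standardization to the stored weights and inspect
# the per-filter statistics (dims 1,2,3 = in_channels, kH, kW).
w = conv.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - m) / torch.sqrt(v + 1e-5)
print(w_std.mean(dim=[1, 2, 3]))                  # ~0 for every filter
print(w_std.var(dim=[1, 2, 3], unbiased=False))   # ~1 for every filter
```

Note that the raw parameters `self.weight` are left untouched; the standardization is recomputed on the fly in every forward pass, so gradients flow through the normalization.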