implus / PytorchInsight

a pytorch lib with state-of-the-art architectures, pretrained models and real-time updated results

Replace with Group Normalization #40

Open Haus226 opened 1 month ago


#2 #33

Regarding the issues above, I think the following lines are the proper way to replace the original implementation with group normalization (GN). The first function uses nn.GroupNorm directly. The second function is a slightly modified version of the official SGE block implementation, changed so that its normalization matches PyTorch's GN.

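# In __init__: a GroupNorm with one group over one channel
self.gn = nn.GroupNorm(1, 1)

def forward(self, x):
    b, c, h, w = x.size()
    # Split the channels into groups: (b * groups, c // groups, h, w)
    x = x.view(b * self.groups, -1, h, w)
    # Weight every position by its group's global average, then sum over the channels
    xn = x * self.avg_pool(x)
    xn = xn.sum(dim=1, keepdim=True)  # (b * groups, 1, h, w)
    # GroupNorm(1, 1) normalizes each group's map over its h * w positions
    t = self.gn(xn)
    x = x * self.sig(t)
    x = x.view(b, c, h, w)
    return x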
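One caveat about the forward above (a suggestion, not part of the original proposal): nn.GroupNorm(1, 1) has a single scalar weight and bias shared by every group, whereas the SGE block learns a separate weight and bias per group. If per-group affine parameters are wanted, one option is to reshape the attention maps to (b, groups, h, w) and use nn.GroupNorm(groups, groups), which puts every channel (here, every SGE group) in its own GN group and gives weight and bias of shape (groups,). The class name SGEWithGN is hypothetical; the check compares it against the manual normalization with random per-group parameters.

```python
import torch
import torch.nn as nn


class SGEWithGN(nn.Module):
    """Hypothetical SGE variant using nn.GroupNorm(groups, groups) for per-group affine."""

    def __init__(self, groups: int = 64, eps: float = 1e-5):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # num_groups == num_channels: each channel of the (b, groups, h, w) map is its
        # own GN group, so weight and bias have shape (groups,), one pair per SGE group.
        self.gn = nn.GroupNorm(groups, groups, eps=eps)
        self.sig = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)                  # (b*g, c/g, h, w)
        xn = (x * self.avg_pool(x)).sum(dim=1, keepdim=True)   # (b*g, 1, h, w)
        t = self.gn(xn.view(b, self.groups, h, w))             # normalize each group over h*w
        x = x * self.sig(t.view(b * self.groups, 1, h, w))
        return x.view(b, c, h, w)


torch.manual_seed(0)
m = SGEWithGN(groups=4)
with torch.no_grad():
    # Random per-group affine parameters so the check is not trivial
    m.gn.weight.uniform_(-1, 1)
    m.gn.bias.uniform_(-1, 1)

x = torch.rand(2, 16, 5, 5)
b, c, h, w = x.shape
g = m.groups
with torch.no_grad():
    # Manual reference: biased variance, eps inside the sqrt, per-group affine
    xg = x.view(b * g, -1, h, w)
    t = (xg * xg.mean(dim=(2, 3), keepdim=True)).sum(dim=1).view(b * g, -1)
    var = t.var(dim=1, keepdim=True, unbiased=False)
    t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + m.gn.eps)
    t = t.view(b, g, h, w) * m.gn.weight.view(1, g, 1, 1) + m.gn.bias.view(1, g, 1, 1)
    ref = (xg * torch.sigmoid(t.view(b * g, 1, h, w))).view(b, c, h, w)
    out = m(x)

max_diff = (out - ref).abs().max().item()
print("max |GN version - manual| =", max_diff)
```

With this layout the GN module carries the same number of affine parameters as the official block, so weights could in principle be copied across by reshaping (1, groups, 1, 1) to (groups,).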

def oforward(self, x):
    b, c, h, w = x.size()
    x = x.view(b * self.groups, -1, h, w)
    xn = x * self.avg_pool(x)
    # Sum the weighted channels within each group to obtain one attention map per group
    # (GN itself has no equivalent of this step)
    xn = xn.sum(dim=1, keepdim=True)
    # Flatten the spatial dimensions of each group
    t = xn.view(b * self.groups, -1)
    # Normalize as GN does: biased (population) variance, with eps inside the square root.
    # Subtracting the mean does not change the variance, so it can be computed on the original t.
    var = t.var(dim=1, keepdim=True, unbiased=False)
    t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
    t = t.view(b, self.groups, h, w)
    # Per-group affine transform
    t = t * self.weight + self.bias
    t = t.view(b * self.groups, 1, h, w)
    x = x * self.sig(t)
    x = x.view(b, c, h, w)
    return x
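The normalization change can be stated precisely. For one group's flattened attention map t = (t_1, ..., t_N) with N = h * w, the official SGE code centres t and divides by its unbiased standard deviation plus eps, while nn.GroupNorm (and the oforward above) divides by the square root of the biased variance plus eps:

```latex
\[
\mu = \frac{1}{N}\sum_{i=1}^{N} t_i
\]
\[
\text{official SGE:}\quad \hat t_i = \frac{t_i - \mu}{s + \epsilon},
\qquad s = \sqrt{\frac{1}{N-1}\sum_{i=1}^{N} (t_i - \mu)^2}
\]
\[
\text{GroupNorm / oforward:}\quad \hat t_i = \frac{t_i - \mu}{\sqrt{\sigma^2 + \epsilon}},
\qquad \sigma^2 = \frac{1}{N}\sum_{i=1}^{N} (t_i - \mu)^2
\]
```

Because the variance of t - c equals the variance of t for any constant c, it makes no difference whether the variance is taken before or after centring; the two versions differ only in the N - 1 versus N divisor and in where eps enters. For 21x21 maps (N = 441) the divisor alone rescales the normalized values by sqrt(441/440), roughly 0.1%.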

The following test code was used; it reported an average maximum difference of 4.3839216232299807e-07:

import torch

running_sum = 0
for _ in range(100):
    t = torch.rand(32, 512, 21, 21)
    m = SGE(64, 512)  # number of groups and input channels
    # Note: .max() only captures positive differences; .abs().max() gives the largest absolute error
    running_sum += (m.forward(t) - m.oforward(t)).max().item()
print("The average maximum difference between the tensors is:", running_sum / 100)
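The SGE class used in the test is not included in the issue, so the snippet above cannot run on its own. Below is a self-contained sketch under stated assumptions: a hypothetical SGE class that takes only the number of groups (the issue's SGE(64, 512) also takes a channel count whose role isn't shown), holds both the GN-based forward and the manual oforward, and initializes the per-group weight and bias to ones and zeros (GN's defaults) so that both paths compute the same function. It uses .abs().max() and 10 trials instead of 100 to keep it quick.

```python
import torch
import torch.nn as nn


class SGE(nn.Module):
    """Hypothetical stand-in for the SGE class used in the test (not shown in the issue)."""

    def __init__(self, groups: int = 64, eps: float = 1e-5):
        super().__init__()
        self.groups = groups
        self.eps = eps
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.sig = nn.Sigmoid()
        self.gn = nn.GroupNorm(1, 1, eps=eps)  # scalar affine: weight=1, bias=0 at init
        # Per-group affine for the manual path, set to GN's defaults so both paths agree
        # (assumption: the official SGE block uses different initial values).
        self.weight = nn.Parameter(torch.ones(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))

    def forward(self, x):  # GroupNorm version
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = (x * self.avg_pool(x)).sum(dim=1, keepdim=True)
        x = x * self.sig(self.gn(xn))
        return x.view(b, c, h, w)

    def oforward(self, x):  # manual version from the issue
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = (x * self.avg_pool(x)).sum(dim=1, keepdim=True)
        t = xn.view(b * self.groups, -1)
        var = t.var(dim=1, keepdim=True, unbiased=False)
        t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
        t = t.view(b, self.groups, h, w) * self.weight + self.bias
        x = x * self.sig(t.view(b * self.groups, 1, h, w))
        return x.view(b, c, h, w)


torch.manual_seed(0)
trials = 10
running_max = 0.0
with torch.no_grad():
    for _ in range(trials):
        t = torch.rand(32, 512, 21, 21)
        m = SGE(64)
        # abs() so that negative differences are counted too
        running_max += (m(t) - m.oforward(t)).abs().max().item()
avg_max_diff = running_max / trials
print("Average maximum absolute difference:", avg_max_diff)
```

If the per-group parameters were instead initialized as in the official SGE block, the two paths would no longer agree at initialization, because GroupNorm(1, 1) starts with weight 1 and bias 0 and shares that single pair across all groups.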