Regarding the issues above, I think the following lines are the proper way to replace the original normalization with `nn.GroupNorm`. The second function is the official SGE implementation, slightly modified to align with how GroupNorm works in PyTorch.
```python
# In __init__: one group over one channel, i.e. per-map standardization
self.gn = nn.GroupNorm(1, 1)

def forward(self, x):
    b, c, h, w = x.size()
    x = x.view(b * self.groups, -1, h, w)
    xn = x * self.avg_pool(x)
    xn = xn.sum(dim=1, keepdim=True)
    xn = xn.view(b * self.groups, -1, h, w)
    t = self.gn(xn)
    x = x * self.sig(t.view(b * self.groups, 1, h, w))
    x = x.view(b, c, h, w)
    return x
```
```python
def oforward(self, x):
    b, c, h, w = x.size()
    x = x.view(b * self.groups, -1, h, w)
    xn = x * self.avg_pool(x)
    # Reduce the weighted channels in each group to obtain one attention map per group
    # (this reduction is not performed in GN)
    xn = xn.sum(dim=1, keepdim=True)
    # Flatten the spatial dimensions within each group
    t = xn.view(b * self.groups, -1)
    # Use the biased variance of the original t with eps inside the sqrt, as GN does,
    # rather than the std of the mean-subtracted t as in the official code
    var = t.var(dim=1, keepdim=True, unbiased=False)
    t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
    t = t.view(b, self.groups, h, w)
    t = t * self.weight + self.bias
    t = t.view(b * self.groups, 1, h, w)
    x = x * self.sig(t)
    x = x.view(b, c, h, w)
    return x
```
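The equivalence the two functions rely on is that `nn.GroupNorm(1, 1)` with its default affine parameters (weight = 1, bias = 0) standardizes each `(1, h, w)` map independently, using the biased variance and adding `eps` inside the square root, exactly the `(t - mean) / sqrt(var + eps)` step above. A quick standalone check (the tensor shape here is arbitrary):

```python
import torch
from torch import nn

gn = nn.GroupNorm(1, 1)  # one group, one channel; eps defaults to 1e-5
x = torch.rand(8, 1, 21, 21)

# Manual standardization over each (1, h, w) map, using the biased variance
mean = x.mean(dim=(1, 2, 3), keepdim=True)
var = x.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + gn.eps)

print(torch.allclose(gn(x), manual, atol=1e-6))
```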
Following is the testing code, with the result: `4.3839216232299807e-07`
```python
running_sum = 0
for _ in range(100):
    t = torch.rand(32, 512, 21, 21)
    m = SGE(64, 512)  # number of groups and input channels
    running_sum += (m.forward(t) - m.oforward(t)).max().item()
print("The average maximum difference between the tensors is:", running_sum / 100)
```
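Since the snippets above reference attributes (`self.groups`, `self.avg_pool`, `self.sig`, `self.weight`, `self.bias`, `self.eps`) without showing the constructor, here is a minimal self-contained sketch of how the module might be assembled for this comparison. The `__init__` below is my assumption, not code from above: in particular, `weight` is set to ones and `bias` to zeros purely so that `oforward`'s affine step matches GroupNorm's default affine parameters (with a different initialization the two paths would disagree until trained).

```python
import torch
from torch import nn

class SGE(nn.Module):
    """Minimal sketch of the SGE block for comparing the two forward passes.

    The __init__ is an assumption (not from the snippets above): weight is
    set to ones and bias to zeros so that oforward's affine step matches the
    default affine parameters of nn.GroupNorm(1, 1).
    """

    def __init__(self, groups, channels=None):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.sig = nn.Sigmoid()
        self.weight = nn.Parameter(torch.ones(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.eps = 1e-5
        self.gn = nn.GroupNorm(1, 1)  # GroupNorm's eps also defaults to 1e-5

    def forward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = x * self.avg_pool(x)
        xn = xn.sum(dim=1, keepdim=True)       # (b * groups, 1, h, w)
        t = self.gn(xn)                        # per-map standardization
        x = x * self.sig(t)
        return x.view(b, c, h, w)

    def oforward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = (x * self.avg_pool(x)).sum(dim=1, keepdim=True)
        t = xn.view(b * self.groups, -1)
        var = t.var(dim=1, keepdim=True, unbiased=False)
        t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
        t = t.view(b, self.groups, h, w) * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)
        x = x * self.sig(t)
        return x.view(b, c, h, w)
```

With this constructor, the two forward passes agree to float32 precision even on a small input, e.g. `SGE(8)` applied to a `(2, 64, 7, 7)` tensor.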