I read the source code and compared the results of the two implementations, and they differ. The original uses `mean` and `sqrt` to calculate the stddev, and that result does not match `torch.std`. I then ported it to PyTorch following the original implementation exactly:
```python
import torch

def minibatch_stddev_layer(x, group_size=4):
    group_size = min(group_size, x.shape[0])                          # Minibatch must be divisible by (or smaller than) group_size.
    s = x.shape                                                       # [NCHW]  Input shape.
    y = torch.reshape(x.clone(), [group_size, -1, s[1], s[2], s[3]])  # [GMCHW] Split minibatch into M groups of size G.
    # y = y.to(torch.float32)                                         # [GMCHW] Cast to FP32 (my PyTorch already runs in FP32).
    y -= torch.mean(y, dim=0, keepdim=True)                           # [GMCHW] Subtract mean over group.
    y = torch.mean(torch.square(y), dim=0)                            # [MCHW]  Calc variance over group.
    y = torch.sqrt(y + 1e-8)                                          # [MCHW]  Calc stddev over group.
    y = torch.mean(y, dim=[1, 2, 3], keepdim=True)                    # [M111]  Take average over fmaps and pixels.
    # y = y.to(x.dtype)                                               # [M111]  Cast back to original data type.
    y = torch.tile(y, [group_size, 1, s[2], s[3]])                    # [N1HW]  Replicate over group and pixels.
    return torch.cat([x, y], dim=1)                                   # [NCHW]  Append as new fmap.
```
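For what it's worth, here is a minimal sketch (my own illustration, not from the repository) of what I believe causes the discrepancy: `torch.std` applies Bessel's correction by default (dividing by N-1), while the mean/sqrt formulation above computes the population stddev (dividing by N).

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Population stddev, as computed by the mean/sqrt formulation above.
pop_std = torch.sqrt(torch.mean(torch.square(x - x.mean())))

# torch.std defaults to the sample stddev with Bessel's correction (N-1).
sample_std = torch.std(x)

# The two only agree when the correction is disabled.
uncorrected_std = torch.std(x, unbiased=False)

print(pop_std.item(), sample_std.item(), uncorrected_std.item())
```

So `torch.std(..., unbiased=False)` would reproduce the original behavior, but I kept the explicit mean/sqrt form (plus the `1e-8` epsilon) to match the reference code exactly.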
If it would be useful, I'd be happy to open a pull request.