aladdinpersson / Machine-Learning-Collection

A resource for learning about Machine learning & Deep Learning
https://www.youtube.com/c/AladdinPersson
MIT License
7.7k stars 2.7k forks

about ProGAN Minibatch standard deviation #138

Closed penway closed 1 year ago

penway commented 1 year ago

I read the source code and compared the results of the two implementations, and they differ. The original implementation uses mean and sqrt to calculate the stddev, and the result does not match torch.std. I then coded this exactly following the original implementation:

import torch

def minibatch_stddev_layer(x, group_size=4):
    group_size = min(group_size, x.shape[0])                          # Minibatch must be divisible by (or smaller than) group_size.
    s = x.shape                                                       # [NCHW]  Input shape.
    y = torch.reshape(x.clone(), [group_size, -1, s[1], s[2], s[3]])  # [GMCHW] Split minibatch into M groups of size G.
    # y = y.to(torch.float32)                                         # [GMCHW] Cast to FP32 (skipped: my tensors are already float32).
    y -= torch.mean(y, dim=0, keepdim=True)                           # [GMCHW] Subtract mean over group.
    y = torch.mean(torch.square(y), dim=0)                            # [MCHW]  Calc variance over group.
    y = torch.sqrt(y + 1e-8)                                          # [MCHW]  Calc stddev over group.
    y = torch.mean(y, dim=[1, 2, 3], keepdim=True)                    # [M111]  Take average over fmaps and pixels.
    # y = y.to(x.dtype)                                               # [M111]  Cast back to original data type.
    y = torch.tile(y, [group_size, 1, s[2], s[3]])                    # [N1HW]  Replicate over group and pixels.
    return torch.cat([x, y], dim=1)                                   # [NCHW]  Append as new fmap.
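
For context, I believe the mismatch with torch.std is expected: torch.std defaults to the unbiased estimator (dividing by N-1, Bessel's correction), while the original TF code computes the biased variance (dividing by N) and adds 1e-8 inside the sqrt. Here is a minimal check of my own (not from the repo) showing this, plus the output shape of the layer above:

import torch

x = torch.randn(8, 16, 4, 4)                            # [NCHW] toy minibatch
out = minibatch_stddev_layer(x, group_size=4)           # uses the function defined above
print(out.shape)                                        # torch.Size([8, 17, 4, 4]) -- one extra fmap

y = torch.randn(4, 16, 4, 4)                            # one group of size 4
manual   = torch.sqrt(torch.mean(torch.square(y - y.mean(dim=0, keepdim=True)), dim=0) + 1e-8)
biased   = torch.std(y, dim=0, unbiased=False)          # divide by N, as in the TF code
unbiased = torch.std(y, dim=0)                          # divide by N-1, the PyTorch default
print(torch.allclose(manual, biased, atol=1e-4))        # True (up to the 1e-8 eps)
print(torch.allclose(manual, unbiased, atol=1e-4))      # False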

If necessary, I would be happy to make a pull request.