google-research / simclr

SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
https://arxiv.org/abs/2006.10029
Apache License 2.0

Modifications to ResNet-50 #137

Open DianCh opened 3 years ago

DianCh commented 3 years ago

Hi, I loaded the r50_1x_sk0 checkpoint for ResNet-50, and it seems that the network architecture isn't exactly the one in the original ResNet paper. However, I didn't see this mentioned in the SimCLRv2 paper. Am I missing anything? If not, why did you modify the ResNet in this way?

Thank you!

(Just posting the printed architecture difference here)

Original: torch.Size([64, 64, 1, 1]), SimCLRv2: torch.Size([256, 64, 1, 1]) - net.1.blocks.0.projection.shortcut.weight
Original: torch.Size([64]), SimCLRv2: torch.Size([256]) - net.1.blocks.0.projection.bn.0.weight
Original: torch.Size([64]), SimCLRv2: torch.Size([256]) - net.1.blocks.0.projection.bn.0.bias
Original: torch.Size([64]), SimCLRv2: torch.Size([256]) - net.1.blocks.0.projection.bn.0.running_mean
Original: torch.Size([64]), SimCLRv2: torch.Size([256]) - net.1.blocks.0.projection.bn.0.running_var
Original: torch.Size([64, 64, 3, 3]), SimCLRv2: torch.Size([64, 64, 1, 1]) - net.1.blocks.0.net.0.weight
Original: torch.Size([256, 64, 1, 1]), SimCLRv2: torch.Size([64, 64, 3, 3]) - net.1.blocks.0.net.2.weight
Original: torch.Size([256]), SimCLRv2: torch.Size([64]) - net.1.blocks.0.net.3.0.weight
Original: torch.Size([256]), SimCLRv2: torch.Size([64]) - net.1.blocks.0.net.3.0.bias
Original: torch.Size([256]), SimCLRv2: torch.Size([64]) - net.1.blocks.0.net.3.0.running_mean
Original: torch.Size([256]), SimCLRv2: torch.Size([64]) - net.1.blocks.0.net.3.0.running_var
Original: torch.Size([128, 256, 1, 1]), SimCLRv2: torch.Size([512, 256, 1, 1]) - net.2.blocks.0.projection.shortcut.weight
Original: torch.Size([128]), SimCLRv2: torch.Size([512]) - net.2.blocks.0.projection.bn.0.weight
Original: torch.Size([128]), SimCLRv2: torch.Size([512]) - net.2.blocks.0.projection.bn.0.bias
Original: torch.Size([128]), SimCLRv2: torch.Size([512]) - net.2.blocks.0.projection.bn.0.running_mean
Original: torch.Size([128]), SimCLRv2: torch.Size([512]) - net.2.blocks.0.projection.bn.0.running_var
Original: torch.Size([128, 128, 3, 3]), SimCLRv2: torch.Size([128, 256, 1, 1]) - net.2.blocks.0.net.0.weight
Original: torch.Size([512, 128, 1, 1]), SimCLRv2: torch.Size([128, 128, 3, 3]) - net.2.blocks.0.net.2.weight
Original: torch.Size([512]), SimCLRv2: torch.Size([128]) - net.2.blocks.0.net.3.0.weight
Original: torch.Size([512]), SimCLRv2: torch.Size([128]) - net.2.blocks.0.net.3.0.bias
Original: torch.Size([512]), SimCLRv2: torch.Size([128]) - net.2.blocks.0.net.3.0.running_mean
Original: torch.Size([512]), SimCLRv2: torch.Size([128]) - net.2.blocks.0.net.3.0.running_var
Original: torch.Size([512, 256, 1, 1]), SimCLRv2: torch.Size([512, 128, 1, 1]) - net.2.blocks.0.net.4.weight
Original: torch.Size([256, 512, 1, 1]), SimCLRv2: torch.Size([1024, 512, 1, 1]) - net.3.blocks.0.projection.shortcut.weight
Original: torch.Size([256]), SimCLRv2: torch.Size([1024]) - net.3.blocks.0.projection.bn.0.weight
Original: torch.Size([256]), SimCLRv2: torch.Size([1024]) - net.3.blocks.0.projection.bn.0.bias
Original: torch.Size([256]), SimCLRv2: torch.Size([1024]) - net.3.blocks.0.projection.bn.0.running_mean
Original: torch.Size([256]), SimCLRv2: torch.Size([1024]) - net.3.blocks.0.projection.bn.0.running_var
Original: torch.Size([256, 256, 3, 3]), SimCLRv2: torch.Size([256, 512, 1, 1]) - net.3.blocks.0.net.0.weight
Original: torch.Size([1024, 256, 1, 1]), SimCLRv2: torch.Size([256, 256, 3, 3]) - net.3.blocks.0.net.2.weight
Original: torch.Size([1024]), SimCLRv2: torch.Size([256]) - net.3.blocks.0.net.3.0.weight
Original: torch.Size([1024]), SimCLRv2: torch.Size([256]) - net.3.blocks.0.net.3.0.bias
Original: torch.Size([1024]), SimCLRv2: torch.Size([256]) - net.3.blocks.0.net.3.0.running_mean
Original: torch.Size([1024]), SimCLRv2: torch.Size([256]) - net.3.blocks.0.net.3.0.running_var
Original: torch.Size([1024, 512, 1, 1]), SimCLRv2: torch.Size([1024, 256, 1, 1]) - net.3.blocks.0.net.4.weight
Original: torch.Size([512, 1024, 1, 1]), SimCLRv2: torch.Size([2048, 1024, 1, 1]) - net.4.blocks.0.projection.shortcut.weight
Original: torch.Size([512]), SimCLRv2: torch.Size([2048]) - net.4.blocks.0.projection.bn.0.weight
Original: torch.Size([512]), SimCLRv2: torch.Size([2048]) - net.4.blocks.0.projection.bn.0.bias
Original: torch.Size([512]), SimCLRv2: torch.Size([2048]) - net.4.blocks.0.projection.bn.0.running_mean
Original: torch.Size([512]), SimCLRv2: torch.Size([2048]) - net.4.blocks.0.projection.bn.0.running_var
Original: torch.Size([512, 512, 3, 3]), SimCLRv2: torch.Size([512, 1024, 1, 1]) - net.4.blocks.0.net.0.weight
Original: torch.Size([2048, 512, 1, 1]), SimCLRv2: torch.Size([512, 512, 3, 3]) - net.4.blocks.0.net.2.weight
Original: torch.Size([2048]), SimCLRv2: torch.Size([512]) - net.4.blocks.0.net.3.0.weight
Original: torch.Size([2048]), SimCLRv2: torch.Size([512]) - net.4.blocks.0.net.3.0.bias
Original: torch.Size([2048]), SimCLRv2: torch.Size([512]) - net.4.blocks.0.net.3.0.running_mean
Original: torch.Size([2048]), SimCLRv2: torch.Size([512]) - net.4.blocks.0.net.3.0.running_var
Original: torch.Size([2048, 1024, 1, 1]), SimCLRv2: torch.Size([2048, 512, 1, 1]) - net.4.blocks.0.net.4.weight
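
For readers who want to reproduce a listing like the one above, here is a minimal sketch of a name-by-name shape comparison. It is not necessarily how the listing was generated: the helper `diff_shapes` and the toy dictionaries are hypothetical, and it assumes both models expose parameters under the same key names. With PyTorch, the dictionaries could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_shapes(original: Dict[str, Shape], other: Dict[str, Shape]) -> List[str]:
    """Return one line per shared parameter name whose shape differs."""
    lines = []
    # Only names present in both models can be compared by shape.
    for name in sorted(original.keys() & other.keys()):
        if original[name] != other[name]:
            lines.append(
                f"Original: {original[name]}, SimCLRv2: {other[name]} - {name}"
            )
    return lines


# Hypothetical toy inputs standing in for two real state dicts.
original = {
    "net.0.weight": (64, 3, 7, 7),
    "net.1.blocks.0.net.0.weight": (64, 64, 3, 3),
}
simclr = {
    "net.0.weight": (64, 3, 7, 7),
    "net.1.blocks.0.net.0.weight": (64, 64, 1, 1),
}

report = diff_shapes(original, simclr)
print("\n".join(report))
```

Note that a comparison keyed on names only reports meaningful differences when the two models number their layers the same way.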

chentingpc commented 3 years ago

We follow the most common ResNet-50 implementation (e.g. TensorFlow, PyTorch). All of these implementations make a small change to the original ResNet (see ResNet-B in https://arxiv.org/pdf/1812.01187.pdf for an explanation), but that shouldn't affect the parameter count.
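
To illustrate why the ResNet-B change leaves the parameter count untouched, here is a small sketch in plain Python. ResNet-B, as described in the linked paper, moves the stride of 2 in a downsampling bottleneck from the first 1x1 convolution to the 3x3 convolution. The helper names below are hypothetical, and the 56x56 input size assumes a standard 224x224 image; the channel sizes match the first block of the second stage in the listing above.

```python
def conv_params(c_in, c_out, k):
    """Weight count of a bias-free k x k convolution; stride does not appear."""
    return c_in * c_out * k * k


def conv_out_size(size, k, stride, padding):
    """Spatial output size of a convolution (standard floor formula)."""
    return (size + 2 * padding - k) // stride + 1


def bottleneck(c_in, mid, c_out, strides, size):
    """Return (total conv weights, output size) for a 1x1 -> 3x3 -> 1x1 path."""
    layers = [
        (c_in, mid, 1, strides[0], 0),   # 1x1 reduce
        (mid, mid, 3, strides[1], 1),    # 3x3
        (mid, c_out, 1, strides[2], 0),  # 1x1 expand
    ]
    total = 0
    for ci, co, k, s, p in layers:
        total += conv_params(ci, co, k)
        size = conv_out_size(size, k, s, p)
    return total, size


# Original ResNet: stride 2 on the first 1x1 convolution.
original = bottleneck(256, 128, 512, strides=(2, 1, 1), size=56)
# ResNet-B: stride 2 moved to the 3x3 convolution.
resnet_b = bottleneck(256, 128, 512, strides=(1, 2, 1), size=56)

print(original, resnet_b)
```

Both variants have the same weight shapes and the same output resolution; only the layer that performs the downsampling differs.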