HobbitLong / SupContrast

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)
BSD 2-Clause "Simplified" License

Code for using the ImageNet pretrained model #146

Open LightingMc opened 8 months ago

LightingMc commented 8 months ago

I thought this would be helpful for other people. I had trouble getting the ResNet defined in this repo to run properly, but the provided weights work well with PyTorch's default ResNet.

    import torch
    import torch.nn as nn
    from torchvision.models import resnet50

    # Loading weights (on CPU)
    state_dict = torch.load("supcon_official.pth", map_location="cpu")

    # Correcting the key names: strip the "module." prefix...
    state_dict = state_dict["model"]
    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace("module.", "")
        new_state_dict[k] = v
    state_dict = new_state_dict

    # ...and then the "encoder." prefix
    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace("encoder.", "")
        new_state_dict[k] = v
    state_dict = new_state_dict

    # Using the standard PyTorch resnet50, with the classifier replaced by an identity
    model = resnet50()
    del model.fc
    model.fc = nn.Identity()

    # Don't need the projection head
    state_dict.pop("head.0.weight", None)
    state_dict.pop("head.0.bias", None)
    state_dict.pop("head.2.weight", None)
    state_dict.pop("head.2.bias", None)

    # This should do the trick
    model.load_state_dict(state_dict, strict=True)
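The two renaming loops and the `pop` calls above could also be folded into a single pass. Below is a minimal sketch in plain Python; `clean_supcon_state_dict` and `check_keys` are hypothetical names, not part of this repo or of PyTorch. It assumes the checkpoint keys look like `module.encoder.<resnet key>` and `module.head.<n>.<param>`, as the snippet above implies. Two deliberate differences from the post: it strips the prefixes only at the start of each key (the post's `str.replace` would also rewrite matches in the middle of a key), and it drops every `head.` entry rather than only `head.0`/`head.2`. `check_keys` mimics the missing/unexpected-key report that a strict `load_state_dict` raises on, so mismatches can be inspected before loading.

```python
from typing import Dict, Iterable, List, Tuple, TypeVar

V = TypeVar("V")


def clean_supcon_state_dict(
    state_dict: Dict[str, V],
    prefixes: Tuple[str, ...] = ("module.", "encoder."),
    drop: Tuple[str, ...] = ("head.",),
) -> Dict[str, V]:
    """Hypothetical helper: strip wrapper prefixes and drop projection-head entries."""
    cleaned: Dict[str, V] = {}
    for key, value in state_dict.items():
        # Strip any of the known prefixes from the start of the key,
        # in whatever order they appear (e.g. "module.encoder." or "encoder.module.")
        changed = True
        while changed:
            changed = False
            for p in prefixes:
                if key.startswith(p):
                    key = key[len(p):]
                    changed = True
        # The projection head has no counterpart in torchvision's ResNet-50
        if key.startswith(drop):
            continue
        cleaned[key] = value
    return cleaned


def check_keys(
    ckpt_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, similar to what a strict load reports."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    missing = sorted(model - ckpt)     # the model expects these, the checkpoint lacks them
    unexpected = sorted(ckpt - model)  # the checkpoint has these, the model does not
    return missing, unexpected
```

With the real checkpoint, one would call `clean_supcon_state_dict(torch.load(...)["model"])` and pass the result and `model.state_dict()` to `check_keys`; two empty lists mean `load_state_dict(..., strict=True)` should not complain about key names (shape mismatches are a separate check).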

DruncBread commented 1 month ago

I have this problem too. Thanks for the tip! I wonder whether there is a difference between the author's ResNet and PyTorch's default ResNet; I hope the performance is the same.

HobbitLong commented 1 month ago

Hi,

Sorry for the confusion. The ResNet (nn.Module file) used in this repo was only for CIFAR input, i.e., 32x32. The ImageNet weights we provided here are for 224x224 input, so they can only be loaded with the official PyTorch definition of ResNet, which takes 224x224 input.
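To make the 32x32 versus 224x224 distinction concrete, the sketch below tracks the side length of the feature map through a ResNet-50 using the standard convolution output-size formula. It assumes the repo's CIFAR ResNet uses the common CIFAR-style stem (a 3x3, stride-1 `conv1` with no max-pooling), while torchvision's ImageNet ResNet uses a 7x7, stride-2 `conv1` followed by a 3x3, stride-2 max-pool. Under that assumption the two `conv1.weight` tensors also differ in shape (64x3x3x3 versus 64x3x7x7), so loading the ImageNet weights into the CIFAR definition would fail on that tensor. The function names are illustrative only.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_out(size: int, cifar: bool) -> int:
    if cifar:
        # Assumed CIFAR-style stem: 3x3 conv, stride 1, padding 1, no max-pool
        return conv_out(size, 3, 1, 1)
    # torchvision ImageNet stem: 7x7 conv (stride 2, pad 3), then 3x3 max-pool (stride 2, pad 1)
    size = conv_out(size, 7, 2, 3)
    return conv_out(size, 3, 2, 1)


def final_feature_map(size: int, cifar: bool) -> int:
    """Side length of the layer4 output, before global average pooling."""
    size = stem_out(size, cifar)
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it
    # (3x3 conv, stride 2, padding 1 in the first block of each stage)
    for _ in range(3):
        size = conv_out(size, 3, 2, 1)
    return size


for cifar, side in [(True, 32), (False, 224), (False, 32), (True, 224)]:
    n = final_feature_map(side, cifar)
    print(f"cifar_stem={cifar}, input {side}x{side} -> layer4 map {n}x{n}")
```

Under these assumptions, a 32x32 image through the CIFAR stem ends at a 4x4 map and a 224x224 image through the ImageNet stem ends at 7x7, whereas 32x32 input through the ImageNet stem collapses to 1x1. Each stem is designed around its own input resolution.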

Historically, we first released this repo for CIFAR-10/100, so we defined the ResNet only for 32x32 input. Later on, I trained a SupCon ImageNet model with a different code base and shared the weights in this repo, which caused this confusion.

LightingMc commented 1 month ago

@HobbitLong that's what I thought as well, hahaha. This code was for CIFAR but the weights were for ImageNet.

HobbitLong commented 1 month ago


Glad you figured it out much earlier, and thank you for sharing it!