LightingMc opened this issue 8 months ago (status: Open)
I have this problem too. Thanks for the tip! I wonder if there is a difference between the author's ResNet and PyTorch's default ResNet; I hope the performance is the same.
Hi,
Sorry for the confusion. The ResNet (nn.Module file) used in this repo is only for CIFAR input, i.e., 32x32. The ImageNet weights we provided here are for 224x224 input, so they can only be loaded with PyTorch's official ResNet definition, which takes 224x224 input.
Historically, we first released this repo for CIFAR-10/100, so we defined the ResNet for 32x32 input only. Later on, I trained a SupCon ImageNet model with another code base of mine and shared the weights in this repo, which caused this confusion.
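To make the 32x32 vs 224x224 distinction concrete, here is a small arithmetic sketch of the spatial size of the last feature map. It assumes torchvision's ImageNet stem (7x7 stride-2 conv followed by a 3x3 stride-2 max-pool) and a typical CIFAR-style stem (a single 3x3 stride-1 conv, no max-pool); the CIFAR stem is an assumption about this repo's definition, not something stated in this thread. `conv_out` and `final_feature_size` are hypothetical helper names.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Side length after a square conv/pool layer (no dilation, floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def final_feature_size(size: int, cifar_stem: bool) -> int:
    """Side length of the ResNet feature map before global average pooling (sketch)."""
    if cifar_stem:
        # Assumed CIFAR-style stem: 3x3 conv, stride 1, and no max-pool
        size = conv_out(size, kernel=3, stride=1, padding=1)
    else:
        # torchvision ImageNet stem: 7x7 conv with stride 2, then 3x3 max-pool with stride 2
        size = conv_out(size, kernel=7, stride=2, padding=3)
        size = conv_out(size, kernel=3, stride=2, padding=1)
    # layer1 keeps the resolution; layer2, layer3 and layer4 each halve it
    # with a stride-2 3x3 conv (padding 1)
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


print(final_feature_size(224, cifar_stem=False))  # 7
print(final_feature_size(32, cifar_stem=False))   # 1
print(final_feature_size(32, cifar_stem=True))    # 4
```

Under these assumptions, the ImageNet stem squeezes a 32x32 image down to a 1x1 map, which is why CIFAR code usually swaps in a lighter stem. The two definitions then also have differently shaped `conv1` weights (3x3 vs 7x7 kernels), so a checkpoint trained with one cannot be loaded into the other.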
@HobbitLong that's what I thought as well hahaha. This code was for CIFAR but the weights were for ImageNet.
Glad you figured it out much earlier, and thank you for sharing it!
I thought this would be helpful for other people. I had issues getting the ResNet used in this repo running properly, but the provided weights work well with PyTorch's default ResNet.
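```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Loading weights (map to CPU so this also works on machines without a GPU).
# On PyTorch >= 2.6, torch.load defaults to weights_only=True; if the checkpoint also
# stores non-tensor objects, you may need weights_only=False (only for trusted files).
checkpoint = torch.load("supcon_official.pth", map_location="cpu")

# Correcting the key names: the weights live under "model", and the keys carry
# "module." (typically added by DataParallel) and "encoder." prefixes that
# torchvision's ResNet does not use
state_dict = checkpoint["model"]
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("module.", "")
    new_state_dict[k] = v
state_dict = new_state_dict

new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("encoder.", "")
    new_state_dict[k] = v
state_dict = new_state_dict

# Using the standard PyTorch resnet50, with the classifier replaced by an identity
# so the model returns the pooled backbone features
model = resnet50()
model.fc = nn.Identity()

# Don't need this: drop the SupCon projection head (head.0 / head.2),
# which is not part of the ResNet backbone
state_dict.pop("head.0.weight", None)
state_dict.pop("head.0.bias", None)
state_dict.pop("head.2.weight", None)
state_dict.pop("head.2.bias", None)

# This should do the trick
model.load_state_dict(state_dict, strict=True)
```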
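The two renaming loops and the `pop` calls can also be folded into one pass. The sketch below is a hypothetical helper (`remap_supcon_keys` is not part of the repo or of PyTorch) that works on any plain dict, so it can be tried without the checkpoint; the example keys are illustrative, not read from `supcon_official.pth`. Unlike `str.replace`, it only strips prefixes at the start of a key, in either order (`module.encoder.` or `encoder.module.`).

```python
def remap_supcon_keys(state_dict, prefixes=("module.", "encoder."), drop_prefix="head."):
    """Strip wrapper prefixes and drop projection-head entries (hypothetical helper)."""
    remapped = {}
    for key, value in state_dict.items():
        # Repeatedly remove leading wrapper prefixes,
        # e.g. "module.encoder.conv1.weight" -> "conv1.weight"
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        # The projection head is not part of the ResNet backbone, so skip it
        if key.startswith(drop_prefix):
            continue
        remapped[key] = value
    return remapped


# Illustrative keys; string values stand in for tensors
example = {
    "module.encoder.conv1.weight": "w0",
    "module.encoder.layer1.0.bn1.running_mean": "w1",
    "module.head.0.weight": "h0",
    "module.head.2.bias": "h1",
}
print(remap_supcon_keys(example))
# {'conv1.weight': 'w0', 'layer1.0.bn1.running_mean': 'w1'}
```

The result can be passed to `model.load_state_dict(..., strict=True)` as above; with `strict=True`, any leftover or missing key raises an error instead of being silently ignored.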