akamaster / pytorch_resnet_cifar10

Proper implementation of ResNet-s for CIFAR10/100 in pytorch that matches description of the original paper.
BSD 2-Clause "Simplified" License

Issues on loading pre-trained model #3

Closed DeeperCS closed 6 years ago

DeeperCS commented 6 years ago

Thank you for providing these very useful pre-trained models. However, I ran into trouble when loading them. Here is what I did:

res20 = resnet20()
weights = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
res20.load_state_dict(weights)

It fails because the keys do not match: for example, the constructed model has "conv1.weight", while the pre-trained weights have "module.conv1.weight".

Would it be possible to provide example code for loading the pre-trained model? Or how else can I solve this problem? Thanks.

DeeperCS commented 6 years ago

Just got it solved. The following code works for me:


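import torch

# Wrap the model in DataParallel so its keys gain the same "module." prefix as the saved weights.
model = torch.nn.DataParallel(resnet20())
model.cuda()

# The file holds a dict; the weights themselves are stored under 'state_dict'.
checkpoint = torch.load('pytorch_resnet_cifar10/pretrained_models/resnet20.th')
model.load_state_dict(checkpoint['state_dict'])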
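
For anyone who wants to load the weights into a plain (non-DataParallel) model, for example on a machine without a GPU, an alternative is to strip the "module." prefix from the saved keys instead of wrapping the model. Below is a minimal sketch of that idea. The helper name `strip_module_prefix` is hypothetical (it is not part of this repo or of PyTorch), and the toy dictionary only stands in for `checkpoint['state_dict']`; its key names are illustrative.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: keys that do not start with the prefix are kept unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy stand-in for checkpoint['state_dict']; in a real checkpoint the values are tensors.
saved = OrderedDict([("module.conv1.weight", 0), ("module.bn1.bias", 1)])
cleaned = strip_module_prefix(saved)
print(list(cleaned))  # ['conv1.weight', 'bn1.bias']
```

With the real checkpoint, the cleaned dictionary could presumably be passed straight to `resnet20().load_state_dict(...)`. Loading the file with `torch.load(path, map_location='cpu')` should avoid needing CUDA at all, though I have not verified that against these particular files.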