bearpaw / pytorch-classification

Classification with PyTorch.
MIT License

The pretrained CIFAR-10 resnet110 is actually resnet164 (Bottleneck) #36

Open jiangyangzhou opened 5 years ago

jiangyangzhou commented 5 years ago

I found that the pretrained CIFAR-10 checkpoint labeled resnet110 is not actually ResNet-110 but ResNet-164. The matching model is `model = resnet(depth=164, block_name='bottleNeck')`. The state_dict loads into this model successfully, but I haven't checked the accuracy. By the way, the state_dict keys contain a 'module' prefix (added by nn.DataParallel), so it can be loaded like this:

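```python
import torch

def load_parallel_weight(model, weight):
    # Checkpoint was saved from a DataParallel-wrapped model; map to CPU so it loads without a GPU
    state_dict = torch.load(weight, map_location='cpu')['state_dict']
    new_dict = {}
    for w in state_dict:
        # Drop the 'module' component, e.g. 'module.conv1.weight' -> 'conv1.weight'
        new_dict['.'.join(filter(lambda x: x != 'module', w.split('.')))] = state_dict[w]
    model.load_state_dict(new_dict)
```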
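A quick arithmetic check of why these two depths get mixed up, assuming this repo's CIFAR ResNet follows the standard construction (one stem conv plus one final fc layer, and three stages of n blocks each, with 2 convs per BasicBlock and 3 per Bottleneck), which gives depth = 6n+2 for basic blocks and 9n+2 for bottleneck blocks. The helper `cifar_resnet_blocks` is hypothetical and not part of the repo.

```python
def cifar_resnet_blocks(depth, block_name):
    """Return n, the number of blocks per stage, for a 3-stage CIFAR ResNet.

    Assumes the standard layout: stem conv + fc layer account for the "+2",
    and each stage has n blocks of 2 convs (BasicBlock) or 3 convs (Bottleneck).
    """
    convs_per_block = {'basicblock': 2, 'bottleneck': 3}[block_name.lower()]
    layers_per_n = 3 * convs_per_block  # 6 for basic blocks, 9 for bottleneck blocks
    if (depth - 2) % layers_per_n != 0:
        raise ValueError(f"depth {depth} is not {layers_per_n}n+2 for {block_name}")
    return (depth - 2) // layers_per_n


print(cifar_resnet_blocks(110, 'basicblock'))  # 18
print(cifar_resnet_blocks(164, 'bottleNeck'))  # 18
```

Both configurations have 18 blocks per stage, so the block layout looks the same, but each Bottleneck block carries a third conv/BN pair and wider outputs. That is consistent with the checkpoint's keys lining up only with `depth=164, block_name='bottleNeck'`.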
jiaqian commented 4 years ago

Aha, that's why the dictionary keys don't match. Could you please upload the pre-trained ResNet-110 model with block_name='basicblock'?
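To see which keys differ before trying to load, a name-only comparison can help. This is a minimal sketch in plain Python; `strip_prefix` and `diff_state_dict_keys` are hypothetical helpers, and the layer names in the example (`layer1.0.conv3.weight` etc.) are illustrative assumptions rather than keys copied from the actual checkpoint. With PyTorch you would pass `torch.load(path, map_location='cpu')['state_dict'].keys()` and `model.state_dict().keys()`.

```python
def strip_prefix(key, prefix='module.'):
    # Remove only a leading DataParallel prefix; inner names are left untouched
    return key[len(prefix):] if key.startswith(prefix) else key


def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) as sorted lists.

    missing: keys the model expects but the checkpoint lacks.
    unexpected: keys the checkpoint has that the model does not define.
    """
    ckpt = {strip_prefix(k) for k in checkpoint_keys}
    model = set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Illustrative keys: a bottleneck-style checkpoint vs. a basic-block model
ckpt_keys = ['module.conv1.weight', 'module.layer1.0.conv1.weight',
             'module.layer1.0.conv2.weight', 'module.layer1.0.conv3.weight']
model_keys = ['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight']
missing, unexpected = diff_state_dict_keys(ckpt_keys, model_keys)
print(missing)     # []
print(unexpected)  # ['layer1.0.conv3.weight']
```

Stripping only a leading `module.` is slightly stricter than the filter in the helper above, which drops every `module` component wherever it appears. `load_state_dict(..., strict=False)` also reports `missing_keys` and `unexpected_keys`, but it still raises on tensor size mismatches (e.g. a 3x3 conv in a basic block vs. a 1x1 conv in a bottleneck under the same key), so comparing names first can be more informative here.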