CSAILVision / places365

The Places365-CNNs for Scene Classification
http://places2.csail.mit.edu/
MIT License

Loading DenseNet161 model does not work #53

Open albertaparicio opened 6 years ago

albertaparicio commented 6 years ago

I have tried the run_placesCNN_basic.py script with the densenet161 architecture, and the model fails to load. The same script works well when I choose a ResNet architecture. I have attached the error log.

Do you have any suggestions why this particular model fails to load?

Thank you

zhoubolei commented 6 years ago

With the upgrade to PyTorch 0.4, the DenseNet model structure was also updated, so the old model cannot be loaded properly. I don't have a solution for that yet.

RiSaMa commented 5 years ago

Same problem here... any solution?

I have been able to load the model as follows:

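    import torch
    import torchvision.models as models

    # arch and model_file are set earlier in run_placesCNN_basic.py
    model = models.__dict__[arch](num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    # Strip the 'module.' prefix that nn.DataParallel adds to every key
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict, strict=False)  # NOTICE THE strict=False
    model.cuda()
    model.eval()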

But the performance is almost random...

Any idea?

xeroxM commented 4 years ago

Hey Guys,

I found a solution for this. As I found out, the key names in the torchvision model and in the trained Places365 model just don't match. I'm not the biggest Python pro, so I came up with a pretty simple and barbaric solution (feel free to refactor it!).

    model = models.__dict__[arch](num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)

    # Strip the 'module.' prefix, then collapse e.g. 'norm.1' -> 'norm1' and 'conv.1' -> 'conv1'
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    state_dict = {k.replace('norm.', 'norm'): v for k, v in state_dict.items()}
    state_dict = {k.replace('conv.', 'conv'): v for k, v in state_dict.items()}
    # Restore the dot in keys without an index (e.g. 'norm.weight'), which the two lines above also collapsed
    state_dict = {k.replace('normweight', 'norm.weight'): v for k, v in state_dict.items()}
    state_dict = {k.replace('normrunning', 'norm.running'): v for k, v in state_dict.items()}
    state_dict = {k.replace('normbias', 'norm.bias'): v for k, v in state_dict.items()}
    state_dict = {k.replace('convweight', 'conv.weight'): v for k, v in state_dict.items()}

I am basically just replacing the key names of the Places365 model with the key names of the model downloaded from the torchvision.models package. I tested the model, and for me it works :)
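For anyone who wants to take up the refactoring offer, here is a minimal sketch of the same renaming done with one regular expression instead of a chain of replacements. The helper names `rename_key` and `rename_state_dict` are hypothetical, and the sample keys are illustrative rather than copied from the real checkpoint. It uses plain dicts so it runs without PyTorch; in practice you would pass `checkpoint['state_dict']`. One small difference: it strips `module.` only as a leading prefix, not wherever it appears in a key.

```python
import re

# Matches a dot between a layer name and its index, e.g. "norm.1" or "conv.2".
# Keys such as "norm.weight" or "norm0.weight" are left alone.
_DOTTED_INDEX = re.compile(r'\b(norm|conv)\.(\d+)\b')


def rename_key(key):
    """Map one old-style key to the newer naming (hypothetical helper)."""
    if key.startswith('module.'):
        key = key[len('module.'):]
    return _DOTTED_INDEX.sub(r'\1\2', key)


def rename_state_dict(state_dict):
    """Return a new dict with every key renamed; values are untouched."""
    return {rename_key(k): v for k, v in state_dict.items()}


# Illustrative keys only; the values would be tensors in a real checkpoint.
old = {
    'module.features.denseblock1.denselayer1.norm.1.weight': 0,
    'module.features.denseblock1.denselayer1.conv.2.weight': 1,
    'module.features.transition1.norm.running_mean': 2,
    'module.classifier.weight': 3,
}
new = rename_state_dict(old)
for key in new:
    print(key)
```

The renamed dict can then be passed to `model.load_state_dict(...)` without `strict=False`, so that any key that still does not match raises an error instead of being skipped silently.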

@RiSaMa strict=False ignores all non-matching keys. So the initial model keeps its initial weights whenever no matching key is found in the new state_dict (in our case, I think that is all of them). As you are most likely loading an untrained model from the torchvision.models package, this results in random performance.
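A quick way to see whether strict=False is silently skipping weights is to compare the two sets of key names before loading. The sketch below is only an illustration: `diff_keys` is a hypothetical helper and the key names are made-up examples. With a real model you would pass `model.state_dict().keys()` and the keys of the checkpoint's state dict.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
                (these layers keep their initial weights)
    unexpected: keys in the checkpoint that the model does not have
                (these weights are dropped)
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative key names only.
model_keys = [
    'features.denseblock1.denselayer1.norm1.weight',
    'classifier.weight',
]
checkpoint_keys = [
    'features.denseblock1.denselayer1.norm.1.weight',
    'classifier.weight',
]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print('missing:', missing)
print('unexpected:', unexpected)
```

If both lists are empty, the names line up and strict=False is not needed. Recent PyTorch versions also report the same information in the value returned by `load_state_dict`.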