albertaparicio opened this issue 6 years ago
Because of the upgrade to PyTorch 0.4, the DenseNet model structure was also updated, so the old model cannot be loaded properly. I don't have a solution for that yet.
Same problem here... any solution?
I have been able to load the model as follows:
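```python
import torch
import torchvision.models as models
arch = 'densenet161'
model_file = '%s_places365.pth.tar' % arch  # checkpoint name pattern from run_placesCNN_basic.py
model = models.__dict__[arch](num_classes=365)
# Load onto the CPU regardless of the device the checkpoint was saved from
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
# Strip the "module." prefix added by nn.DataParallel
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict, strict=False)  # NOTICE THE strict=False
model.cuda()
model.eval()
```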
But the performance is almost random...
Any idea?
Hey Guys,
I found a solution for this. As it turns out, the key names in the torchvision model and in the trained Places365 checkpoint just don't match. I'm not the biggest Python pro, so I came up with a pretty simple and barbaric solution (feel free to refactor it!).
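```python
import torch
import torchvision.models as models

arch = 'densenet161'
model_file = '%s_places365.pth.tar' % arch
model = models.__dict__[arch](num_classes=365)
checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
# Rename dense-layer keys: "norm.1.weight" -> "norm1.weight", "conv.2.weight" -> "conv2.weight"
state_dict = {k.replace('norm.', 'norm'): v for k, v in state_dict.items()}
state_dict = {k.replace('conv.', 'conv'): v for k, v in state_dict.items()}
# Undo the side effect on transition-layer keys, e.g. "norm.weight" -> "normweight" -> "norm.weight"
state_dict = {k.replace('normweight', 'norm.weight'): v for k, v in state_dict.items()}
state_dict = {k.replace('normrunning', 'norm.running'): v for k, v in state_dict.items()}
state_dict = {k.replace('normbias', 'norm.bias'): v for k, v in state_dict.items()}
state_dict = {k.replace('convweight', 'conv.weight'): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)  # strict by default, so any remaining mismatch raises an error
```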
I am basically just renaming the keys of the Places365 checkpoint so they match the key names of the model built from the torchvision.models package. I tested the model, and for me it works :)
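Since refactoring was invited, here is one possible tidier version: a single regex that only touches dense-layer keys, so the transition-layer keys never need repairing. This is a sketch under the assumption that old keys look like `...denselayerN.norm.1.weight` and new ones like `...denselayerN.norm1.weight`, as the replace chain above implies; `remap_densenet_keys` is a hypothetical helper name, not part of torchvision.

```python
import re
from collections import OrderedDict

# Matches old-style dense-layer keys, e.g. "...denselayer3.norm.1.weight".
# Assumption: only norm/conv sub-modules numbered 1 or 2 were renamed.
_DENSE_LAYER_KEY = re.compile(
    r'^(.*denselayer\d+\.(?:norm|conv))\.([12]\.(?:weight|bias|running_mean|running_var))$'
)


def remap_densenet_keys(state_dict):
    """Hypothetical helper: map old Places365 DenseNet keys to the newer naming."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        # Drop the "module." prefix that nn.DataParallel adds when saving
        if key.startswith('module.'):
            key = key[len('module.'):]
        match = _DENSE_LAYER_KEY.match(key)
        if match:
            # "norm.1.weight" -> "norm1.weight", "conv.2.weight" -> "conv2.weight"
            key = match.group(1) + match.group(2)
        remapped[key] = value
    return remapped
```

With the model and checkpoint from the snippet above, this would be used as `model.load_state_dict(remap_densenet_keys(checkpoint['state_dict']))`. Keeping the default `strict=True` means any key the regex misses raises an error instead of silently leaving a layer untrained. Unlike `str.replace`, this only strips `module.` when it is a prefix.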
@RiSaMa `strict=False` ignores all keys that don't match, so the model keeps its initial weights wherever no matching key is found in the new state_dict (in our case I think that's every dense-layer weight). Since `models.__dict__[arch](num_classes=365)` builds an untrained model from the torchvision.models package, those layers stay randomly initialized, which results in near-random performance.
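To see what `strict=False` is hiding, you can compare the two key sets before loading. A minimal sketch, assuming you pass in `model.state_dict().keys()` and the checkpoint's keys; `compare_state_dict_keys` is a hypothetical helper, not a PyTorch function.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as sorted lists.

    missing:    keys the model expects but the checkpoint lacks; with
                strict=False these keep their initial (random) values.
    unexpected: keys in the checkpoint the model has no slot for; with
                strict=False they are silently dropped.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected
```

For example, `missing, unexpected = compare_state_dict_keys(model.state_dict().keys(), state_dict.keys())`. A long `missing` list that mirrors `unexpected` with slightly different names is the signature of this issue. In recent PyTorch versions, `load_state_dict` also returns an object with `missing_keys` and `unexpected_keys` fields that can be inspected the same way.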
I have tried the `run_placesCNN_basic.py` script with the `densenet161` architecture, and the model fails to load. It works well when I choose a ResNet architecture. I have attached the error log. Do you have any suggestions why this particular model fails to load?
Thank you