taohan10200 / IIM

PyTorch implementations of the paper: "Learning Independent Instance Maps for Crowd Localization"
MIT License

Cannot load pretrained models #5

Closed: bitwalt closed this issue 3 years ago

bitwalt commented 3 years ago

Hi, thank you for sharing your code! This model looks very interesting and promising. I was trying to test it on a video, but unfortunately I was not able to load your pre-trained models. I tried both the HR and VGG models, but loading always breaks at load_state_dict(). Do you know why?

netName = 'HR_Net'
GPU_ID = '0'
torch.backends.cudnn.benchmark = True
model_path = './saved_model/NWPU-HR-ep_241_F1_0.802_Pre_0.841_Rec_0.766_mae_55.6_mse_330.9.pth'
net = Crowd_locator(netName,GPU_ID,pretrained=True)

File "/home/walter/anaconda3/envs/crowd/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Crowd_locator:
    Missing key(s) in state_dict: "Extractor.conv1.weight", "Extractor.bn1.weight", "Extractor.bn1.bias", "Extractor.bn1.running_mean", "Extractor.bn1.running_var", "Extractor.conv2.weight", "Extractor.bn2.weight", "Extractor.bn2.bias", "Extractor.bn2.running_mean", "Extractor.bn2.running_var", "Extractor.layer1.0.conv1.weight", "Extractor.layer1.0.bn1.weight", "Extractor.layer1.0.bn1.bias", "Extractor.layer1.0.bn1.running_mean", "Extractor.layer1.0.bn1.running_var", "Extractor.layer1.0.conv2.weight", "Extractor.layer1.0.bn2.weight", "Extractor.layer1.0.bn2.bias", "Extractor.layer1.0.bn2.running_mean", [...]

gjy3035 commented 3 years ago


taohan10200 commented 3 years ago

This error occurs because we saved the model during multi-GPU training, so the saved parameter names contain an extra 'module.' segment, while you loaded it into a single-GPU model whose parameter names do not. We have updated the single-GPU loading path. You can download the latest test.py, or modify your copy by referring to the following code:

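import torch
from collections import OrderedDict

net = Crowd_locator(netName, GPU_ID, pretrained=True)
state_dict = torch.load(model_path)
if len(GPU_ID.split(',')) > 1:
    # Multi-GPU: load the checkpoint as saved
    net.load_state_dict(state_dict)
else:
    # Single GPU: drop the 'module.' segment added during multi-GPU training
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v
    net.load_state_dict(new_state_dict)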
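The key renaming above can be made slightly more robust. The following is a minimal sketch, not code from the IIM repository: the helper names `strip_dataparallel_prefix` and `key_mismatch` are hypothetical, and it assumes no real layer in the network is literally named `module`. It removes whole `module` path segments, so both `Extractor.module.conv1.weight` and `module.Extractor.conv1.weight` become `Extractor.conv1.weight`, whereas `k.replace('module.', '')` would also turn a layer called `submodule` into `sub`. The second helper reports missing and unexpected keys so you can check the mismatch before calling `load_state_dict`.

```python
from collections import OrderedDict
from typing import Iterable, List, Mapping, Tuple


def strip_dataparallel_prefix(state_dict: Mapping) -> OrderedDict:
    """Return a copy of state_dict with every 'module' path segment removed.

    Hypothetical helper. Assumes no genuine layer is named 'module';
    names such as 'submodule.weight' are left untouched.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Split on '.', drop segments that are exactly 'module', rejoin
        parts = [p for p in key.split(".") if p != "module"]
        cleaned[".".join(parts)] = value
    return cleaned


def key_mismatch(expected_keys: Iterable[str],
                 checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, both sorted, for a quick comparison."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return sorted(expected - found), sorted(found - expected)


if __name__ == "__main__":
    # Hypothetical checkpoint keys with an extra 'module' segment
    saved = OrderedDict([
        ("Extractor.module.conv1.weight", 1),
        ("Extractor.module.bn1.bias", 2),
    ])
    wanted = ["Extractor.conv1.weight", "Extractor.bn1.bias"]

    print(key_mismatch(wanted, saved.keys()))
    print(key_mismatch(wanted, strip_dataparallel_prefix(saved).keys()))  # -> ([], [])
```

With PyTorch installed, the helper would be applied as `net.load_state_dict(strip_dataparallel_prefix(torch.load(model_path, map_location='cpu')))`; passing `map_location='cpu'` lets a checkpoint saved on GPUs be loaded on a machine without them.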

bitwalt commented 3 years ago

Thank you both very much, that was it! This is really great work.