NVlabs / AL-MDN

Official PyTorch implementation of "Active Learning for Deep Object Detection via Probabilistic Modeling" (ICCV 2021)
https://openaccess.thecvf.com/content/ICCV2021/html/Choi_Active_Learning_for_Deep_Object_Detection_via_Probabilistic_Modeling_ICCV_2021_paper.html

RuntimeError: Error(s) in loading state_dict for DataParallel: #8

Closed: ayennam closed this issue 2 years ago

ayennam commented 2 years ago

python eval_coco.py --dataset_root /coco --trained_model weights/vgg16_reducedfc.pth
81
:20: UserWarning: nn.init.constant is now deprecated in favor of nn.init.constant_.
  init.constant(self.weight,self.gamma)
True
Loading weight: weights/vgg16_reducedfc.pth
odict_keys(['0.weight', '0.bias', '2.weight', '2.bias', '5.weight', '5.bias', '7.weight', '7.bias', '10.weight', '10.bias', '12.weight', '12.bias', '14.weight', '14.bias', '17.weight', '17.bias', '19.weight', '19.bias', '21.weight', '21.bias', '24.weight', '24.bias', '26.weight', '26.bias', '28.weight', '28.bias', '31.weight', '31.bias', '33.weight', '33.bias'])
Traceback (most recent call last):
  File "eval_coco.py", line 188, in <module>
    net.load_state_dict(ckp['weight'] if 'weight' in ckp.keys() else ckp)
  File "/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.vgg.0.weight", "module.vgg.0.bias", "module.vgg.2.weight", "module.vgg.2.bias", "module.vgg.5.weight", "module.vgg.5.bias", "module.vgg.7.weight", "module.vgg.7.bias", "module.vgg.10.weight", "module.vgg.10.bias", "module.vgg.12.weight", "module.vgg.12.bias", "module.vgg.14.weight", "module.vgg.14.bias", "module.vgg.17.weight", "module.vgg.17.bias", "module.vgg.19.weight", "module.vgg.19.bias", "module.vgg.21.weight", "module.vgg.21.bias", "module.vgg.24.weight", "module.vgg.24.bias", "module.vgg.26.weight", "module.vgg.26.bias", "module.vgg.28.weight", "module.vgg.28.bias", "module.vgg.31.weight", "module.vgg.31.bias", "module.vgg.33.weight", "module.vgg.33.bias", "module.L2Norm.weight", "module.extras.0.weight", "module.extras.0.bias", "module.extras.1.weight", "module.extras.1.bias", "module.extras.2.weight", "module.extras.2.bias", "module.extras.3.weight", "module.extras.3.bias", "module.extras.4.weight", "module.extras.4.bias", "module.extras.5.weight", "module.extras.5.bias",
"module.extras.6.weight", "module.extras.6.bias", "module.extras.7.weight", "module.extras.7.bias", "module.loc_mu_1.0.weight", "module.loc_mu_1.0.bias", "module.loc_mu_1.1.weight", "module.loc_mu_1.1.bias", "module.loc_mu_1.2.weight", "module.loc_mu_1.2.bias", "module.loc_mu_1.3.weight", "module.loc_mu_1.3.bias", "module.loc_mu_1.4.weight", "module.loc_mu_1.4.bias", "module.loc_mu_1.5.weight", "module.loc_mu_1.5.bias", "module.loc_var_1.0.weight", "module.loc_var_1.0.bias", "module.loc_var_1.1.weight", "module.loc_var_1.1.bias", "module.loc_var_1.2.weight", "module.loc_var_1.2.bias", "module.loc_var_1.3.weight", "module.loc_var_1.3.bias", "module.loc_var_1.4.weight", "module.loc_var_1.4.bias", "module.loc_var_1.5.weight", "module.loc_var_1.5.bias", "module.loc_pi_1.0.weight", "module.loc_pi_1.0.bias", "module.loc_pi_1.1.weight", "module.loc_pi_1.1.bias", "module.loc_pi_1.2.weight", "module.loc_pi_1.2.bias", "module.loc_pi_1.3.weight", "module.loc_pi_1.3.bias", "module.loc_pi_1.4.weight", "module.loc_pi_1.4.bias", "module.loc_pi_1.5.weight", "module.loc_pi_1.5.bias", "module.loc_mu_2.0.weight", "module.loc_mu_2.0.bias", "module.loc_mu_2.1.weight", "module.loc_mu_2.1.bias", "module.loc_mu_2.2.weight", "module.loc_mu_2.2.bias", "module.loc_mu_2.3.weight", "module.loc_mu_2.3.bias", "module.loc_mu_2.4.weight", "module.loc_mu_2.4.bias", "module.loc_mu_2.5.weight", "module.loc_mu_2.5.bias", "module.loc_var_2.0.weight", "module.loc_var_2.0.bias", "module.loc_var_2.1.weight", "module.loc_var_2.1.bias", "module.loc_var_2.2.weight", "module.loc_var_2.2.bias", "module.loc_var_2.3.weight", "module.loc_var_2.3.bias", "module.loc_var_2.4.weight", "module.loc_var_2.4.bias", "module.loc_var_2.5.weight", "module.loc_var_2.5.bias", "module.loc_pi_2.0.weight", "module.loc_pi_2.0.bias", "module.loc_pi_2.1.weight", "module.loc_pi_2.1.bias", "module.loc_pi_2.2.weight", "module.loc_pi_2.2.bias", "module.loc_pi_2.3.weight", "module.loc_pi_2.3.bias", "module.loc_pi_2.4.weight", 
"module.loc_pi_2.4.bias", "module.loc_pi_2.5.weight", "module.loc_pi_2.5.bias", "module.loc_mu_3.0.weight", "module.loc_mu_3.0.bias", "module.loc_mu_3.1.weight", "module.loc_mu_3.1.bias", "module.loc_mu_3.2.weight", "module.loc_mu_3.2.bias", "module.loc_mu_3.3.weight", "module.loc_mu_3.3.bias", "module.loc_mu_3.4.weight", "module.loc_mu_3.4.bias", "module.loc_mu_3.5.weight", "module.loc_mu_3.5.bias", "module.loc_var_3.0.weight", "module.loc_var_3.0.bias", "module.loc_var_3.1.weight", "module.loc_var_3.1.bias", "module.loc_var_3.2.weight", "module.loc_var_3.2.bias", "module.loc_var_3.3.weight", "module.loc_var_3.3.bias", "module.loc_var_3.4.weight", "module.loc_var_3.4.bias", "module.loc_var_3.5.weight", "module.loc_var_3.5.bias", "module.loc_pi_3.0.weight", "module.loc_pi_3.0.bias", "module.loc_pi_3.1.weight", "module.loc_pi_3.1.bias", "module.loc_pi_3.2.weight", "module.loc_pi_3.2.bias", "module.loc_pi_3.3.weight", "module.loc_pi_3.3.bias", "module.loc_pi_3.4.weight", "module.loc_pi_3.4.bias", "module.loc_pi_3.5.weight", "module.loc_pi_3.5.bias", "module.loc_mu_4.0.weight", "module.loc_mu_4.0.bias", "module.loc_mu_4.1.weight", "module.loc_mu_4.1.bias", "module.loc_mu_4.2.weight", "module.loc_mu_4.2.bias", "module.loc_mu_4.3.weight", "module.loc_mu_4.3.bias", "module.loc_mu_4.4.weight", "module.loc_mu_4.4.bias", "module.loc_mu_4.5.weight", "module.loc_mu_4.5.bias", "module.loc_var_4.0.weight", "module.loc_var_4.0.bias", "module.loc_var_4.1.weight", "module.loc_var_4.1.bias", "module.loc_var_4.2.weight", "module.loc_var_4.2.bias", "module.loc_var_4.3.weight", "module.loc_var_4.3.bias", "module.loc_var_4.4.weight", "module.loc_var_4.4.bias", "module.loc_var_4.5.weight", "module.loc_var_4.5.bias", "module.loc_pi_4.0.weight", "module.loc_pi_4.0.bias", "module.loc_pi_4.1.weight", "module.loc_pi_4.1.bias", "module.loc_pi_4.2.weight", "module.loc_pi_4.2.bias", "module.loc_pi_4.3.weight", "module.loc_pi_4.3.bias", "module.loc_pi_4.4.weight", "module.loc_pi_4.4.bias", 
"module.loc_pi_4.5.weight", "module.loc_pi_4.5.bias", "module.conf_mu_1.0.weight", "module.conf_mu_1.0.bias", "module.conf_mu_1.1.weight", "module.conf_mu_1.1.bias", "module.conf_mu_1.2.weight", "module.conf_mu_1.2.bias", "module.conf_mu_1.3.weight", "module.conf_mu_1.3.bias", "module.conf_mu_1.4.weight", "module.conf_mu_1.4.bias", "module.conf_mu_1.5.weight", "module.conf_mu_1.5.bias", "module.conf_var_1.0.weight", "module.conf_var_1.0.bias", "module.conf_var_1.1.weight", "module.conf_var_1.1.bias", "module.conf_var_1.2.weight", "module.conf_var_1.2.bias", "module.conf_var_1.3.weight", "module.conf_var_1.3.bias", "module.conf_var_1.4.weight", "module.conf_var_1.4.bias", "module.conf_var_1.5.weight", "module.conf_var_1.5.bias", "module.conf_pi_1.0.weight", "module.conf_pi_1.0.bias", "module.conf_pi_1.1.weight", "module.conf_pi_1.1.bias", "module.conf_pi_1.2.weight", "module.conf_pi_1.2.bias", "module.conf_pi_1.3.weight", "module.conf_pi_1.3.bias", "module.conf_pi_1.4.weight", "module.conf_pi_1.4.bias", "module.conf_pi_1.5.weight", "module.conf_pi_1.5.bias", "module.conf_mu_2.0.weight", "module.conf_mu_2.0.bias", "module.conf_mu_2.1.weight", "module.conf_mu_2.1.bias", "module.conf_mu_2.2.weight", "module.conf_mu_2.2.bias", "module.conf_mu_2.3.weight", "module.conf_mu_2.3.bias", "module.conf_mu_2.4.weight", "module.conf_mu_2.4.bias", "module.conf_mu_2.5.weight", "module.conf_mu_2.5.bias", "module.conf_var_2.0.weight", "module.conf_var_2.0.bias", "module.conf_var_2.1.weight", "module.conf_var_2.1.bias", "module.conf_var_2.2.weight", "module.conf_var_2.2.bias", "module.conf_var_2.3.weight", "module.conf_var_2.3.bias", "module.conf_var_2.4.weight", "module.conf_var_2.4.bias", "module.conf_var_2.5.weight", "module.conf_var_2.5.bias", "module.conf_pi_2.0.weight", "module.conf_pi_2.0.bias", "module.conf_pi_2.1.weight", "module.conf_pi_2.1.bias", "module.conf_pi_2.2.weight", "module.conf_pi_2.2.bias", "module.conf_pi_2.3.weight", "module.conf_pi_2.3.bias", 
"module.conf_pi_2.4.weight", "module.conf_pi_2.4.bias", "module.conf_pi_2.5.weight", "module.conf_pi_2.5.bias", "module.conf_mu_3.0.weight", "module.conf_mu_3.0.bias", "module.conf_mu_3.1.weight", "module.conf_mu_3.1.bias", "module.conf_mu_3.2.weight", "module.conf_mu_3.2.bias", "module.conf_mu_3.3.weight", "module.conf_mu_3.3.bias", "module.conf_mu_3.4.weight", "module.conf_mu_3.4.bias", "module.conf_mu_3.5.weight", "module.conf_mu_3.5.bias", "module.conf_var_3.0.weight", "module.conf_var_3.0.bias", "module.conf_var_3.1.weight", "module.conf_var_3.1.bias", "module.conf_var_3.2.weight", "module.conf_var_3.2.bias", "module.conf_var_3.3.weight", "module.conf_var_3.3.bias", "module.conf_var_3.4.weight", "module.conf_var_3.4.bias", "module.conf_var_3.5.weight", "module.conf_var_3.5.bias", "module.conf_pi_3.0.weight", "module.conf_pi_3.0.bias", "module.conf_pi_3.1.weight", "module.conf_pi_3.1.bias", "module.conf_pi_3.2.weight", "module.conf_pi_3.2.bias", "module.conf_pi_3.3.weight", "module.conf_pi_3.3.bias", "module.conf_pi_3.4.weight", "module.conf_pi_3.4.bias", "module.conf_pi_3.5.weight", "module.conf_pi_3.5.bias", "module.conf_mu_4.0.weight", "module.conf_mu_4.0.bias", "module.conf_mu_4.1.weight", "module.conf_mu_4.1.bias", "module.conf_mu_4.2.weight", "module.conf_mu_4.2.bias", "module.conf_mu_4.3.weight", "module.conf_mu_4.3.bias", "module.conf_mu_4.4.weight", "module.conf_mu_4.4.bias", "module.conf_mu_4.5.weight", "module.conf_mu_4.5.bias", "module.conf_var_4.0.weight", "module.conf_var_4.0.bias", "module.conf_var_4.1.weight", "module.conf_var_4.1.bias", "module.conf_var_4.2.weight", "module.conf_var_4.2.bias", "module.conf_var_4.3.weight", "module.conf_var_4.3.bias", "module.conf_var_4.4.weight", "module.conf_var_4.4.bias", "module.conf_var_4.5.weight", "module.conf_var_4.5.bias", "module.conf_pi_4.0.weight", "module.conf_pi_4.0.bias", "module.conf_pi_4.1.weight", "module.conf_pi_4.1.bias", "module.conf_pi_4.2.weight", "module.conf_pi_4.2.bias", 
"module.conf_pi_4.3.weight", "module.conf_pi_4.3.bias", "module.conf_pi_4.4.weight", "module.conf_pi_4.4.bias", "module.conf_pi_4.5.weight", "module.conf_pi_4.5.bias".
    Unexpected key(s) in state_dict: "0.weight", "0.bias", "2.weight", "2.bias", "5.weight", "5.bias", "7.weight", "7.bias", "10.weight", "10.bias", "12.weight", "12.bias", "14.weight", "14.bias", "17.weight", "17.bias", "19.weight", "19.bias", "21.weight", "21.bias", "24.weight", "24.bias", "26.weight", "26.bias", "28.weight", "28.bias", "31.weight", "31.bias", "33.weight", "33.bias".

stevensmile00119 commented 2 years ago

vgg16_reducedfc.pth holds the pretrained backbone weights; use it to initialize training and produce your own model weights.

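The error comes from the line torch.nn.DataParallel(net): wrapping the model adds a module. prefix to every key in its state_dict, so those keys no longer match the checkpoint. In addition, vgg16_reducedfc.pth only contains the VGG backbone layers (keys such as 0.weight), so it cannot match the full model's state_dict anyway. To solve the problem, load it into the backbone only: net.vgg.load_state_dict(vgg_weights) before the model is wrapped in DataParallel, or net.module.vgg.load_state_dict(vgg_weights) after.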
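The key mismatch can be reproduced without PyTorch by treating a state_dict as a plain mapping of names to tensors. This is a minimal sketch, assuming only a handful of the key names shown in the traceback above; the values are dummy placeholders, and `strip_prefix`, `submodule_keys` and `key_diff` are hypothetical helpers, not part of AL-MDN or PyTorch. It shows that stripping `module.` alone is not enough, while comparing against the `vgg` submodule's keys matches the backbone file exactly.

```python
from collections import OrderedDict


def strip_prefix(keys, prefix):
    """Remove `prefix` from every key that starts with it; other keys pass through."""
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


def submodule_keys(keys, prefix):
    """Keep only the keys under `prefix` (e.g. 'module.vgg.') and drop that prefix."""
    return [k[len(prefix):] for k in keys if k.startswith(prefix)]


def key_diff(model_keys, ckpt_keys):
    """Mimic the strict key check of load_state_dict: return (missing, unexpected)."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


# A few of the DataParallel model's keys from the traceback (the real list is longer).
model_keys = [
    "module.vgg.0.weight", "module.vgg.0.bias",
    "module.L2Norm.weight", "module.loc_mu_1.0.weight",
]
# Hypothetical stand-in for vgg16_reducedfc.pth: bare VGG layer indices, no prefixes.
vgg_weights = OrderedDict([("0.weight", None), ("0.bias", None)])

# Loading into the whole wrapped model fails, as in the issue.
missing, unexpected = key_diff(model_keys, vgg_weights)
# Removing 'module.' is not enough: 'vgg.0.weight' still differs from '0.weight'.
missing_stripped, _ = key_diff(strip_prefix(model_keys, "module."), vgg_weights)
# Comparing against the vgg submodule only (what net.module.vgg sees) matches exactly.
vgg_missing, vgg_unexpected = key_diff(submodule_keys(model_keys, "module.vgg."), vgg_weights)
```

With real modules the same reasoning applies: `net.module.vgg.load_state_dict(vgg_weights)` targets only the backbone of a DataParallel-wrapped model, while loading into `net` itself checks the file against every key of the detector.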

jwchoi384 commented 2 years ago

Hello @ayennam, as @shadowtraveler said, vgg16_reducedfc.pth is a pre-trained backbone weight file, not a checkpoint of the full model. To get COCO evaluation results, you need to use weights you trained yourself on the COCO dataset.
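Before running evaluation, it can help to check what a .pth file actually contains. The sketch below assumes the file has already been loaded into a dict (for example with `torch.load`) and uses a hypothetical `describe_checkpoint` heuristic derived from the key names in this issue: a full model checkpoint has `vgg.*` keys plus `loc_*`/`conf_*` head keys, while vgg16_reducedfc.pth only has bare numeric layer keys such as `0.weight`. It also unwraps the `{'weight': ...}` layout that line 188 of eval_coco.py accepts.

```python
def describe_checkpoint(ckp):
    """Heuristically classify a loaded checkpoint dict (hypothetical helper)."""
    # eval_coco.py accepts either {'weight': state_dict} or a bare state_dict.
    state = ckp["weight"] if "weight" in ckp else ckp
    # Drop the 'module.' prefix that DataParallel adds, if present.
    keys = [k[len("module."):] if k.startswith("module.") else k for k in state]
    has_backbone = any(k.startswith("vgg.") for k in keys)
    has_heads = any(k.startswith(("loc_", "conf_")) for k in keys)
    if has_backbone and has_heads:
        return "full model"
    # Backbone-only files use bare layer indices such as '0.weight'.
    if keys and all(k.split(".", 1)[0].isdigit() for k in keys):
        return "backbone only"
    return "unknown"


# Dummy stand-ins for loaded checkpoints (values would be tensors in practice).
backbone = {"0.weight": None, "0.bias": None, "33.weight": None}
trained = {"weight": {"module.vgg.0.weight": None,
                      "module.loc_mu_1.0.weight": None,
                      "module.conf_pi_1.0.weight": None}}
print(describe_checkpoint(backbone))  # backbone only
print(describe_checkpoint(trained))   # full model
```

Only a checkpoint reported as a full model is a sensible `--trained_model` argument for eval_coco.py; a backbone-only file belongs in the training step.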