xiaogang00 / Adversarial_Distribution_Alignment

The project for General Adversarial Defense Against Black-box Attacks via Pixel Level and Feature Level Distribution Alignments.

Pretrained model weights mismatch #1

Open · Jtao0818 opened 1 year ago

Jtao0818 commented 1 year ago

Thanks for your excellent work!

I am trying to train the semantic segmentation model on VOC2012 with your pretrained weights. An error occurs when loading the weights; I suspect the model settings used to produce the weights differ from those in the code. Could you check it?

The command is python train_voc2012.py --name voc_model_pspnet --no_instance --label_nc 22 --tf_log --model pix2pixHD_voc2012

The log is as follows:

Traceback (most recent call last):
  File "train_voc2012.py", line 41, in <module>
    model = create_model(opt)
  File "/media/sse1080/D/ParcharmProject/Adversarial_Distribution_Alignment-main/models/models.py", line 42, in create_model
    model.initialize(opt)
  File "/media/sse1080/D/ParcharmProject/Adversarial_Distribution_Alignment-main/models/pix2pixHD_model_voc2012.py", line 132, in initialize
    seg_model.load_state_dict(checkpoint['state_dict'], strict=False)
  File "/home/sse1080/anaconda3/envs/adv_align/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    size mismatch for module.cls.0.weight: copying a param with shape torch.Size([512, 4096, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1280, 1, 1]).
    size mismatch for module.cls.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for module.cls.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for module.cls.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for module.cls.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
    size mismatch for module.cls.4.weight: copying a param with shape torch.Size([21, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([21, 256, 1, 1]).
    size mismatch for module.aux.0.weight: copying a param with shape torch.Size([256, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]).
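Note that strict=False does not help here: it only tolerates missing or unexpected keys, not shape mismatches, so the checkpoint's cls/aux head (4096-channel inputs, 3x3 convolutions) cannot be copied into the smaller head (1280-channel inputs, 1x1 convolutions) that the current code builds. If you just want the script to run while the model settings are being sorted out, a common workaround is to drop the mismatched entries before loading. A minimal sketch, assuming `checkpoint` and `seg_model` as they appear in the traceback (the checkpoint path below is hypothetical):

```python
import torch

# Hypothetical path; point this at the released pretrained weights.
checkpoint = torch.load('checkpoints/pspnet_voc2012.pth', map_location='cpu')
pretrained_state = checkpoint['state_dict']

model_state = seg_model.state_dict()
# Keep only entries whose name AND shape match the current model.
# strict=False alone does not skip shape mismatches, which is why the
# RuntimeError above is raised even though it is set.
filtered = {k: v for k, v in pretrained_state.items()
            if k in model_state and v.shape == model_state[k].shape}
skipped = sorted(set(pretrained_state) - set(filtered))
print('Skipped (name/shape mismatch):', skipped)

model_state.update(filtered)
seg_model.load_state_dict(model_state)
```

This only sidesteps the error; the skipped layers remain randomly initialized, so the proper fix is to construct the same PSPNet variant the checkpoint was trained with (see the reply below).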

honhaochen commented 6 months ago

Change this line (where the segmentation network is constructed in pix2pixHD_model_voc2012.py, just before its weights are loaded at line 132) to use pspnet.PSPNet.
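For anyone hitting the same error, a hedged sketch of what that change might look like. The import path and constructor arguments are assumptions based on common PSPNet implementations, not the repo's confirmed API; check the repository's pspnet.py for the actual signature:

```python
import torch
from models import pspnet  # module path assumed from the repo layout

# Build the full PSPNet so the cls/aux head shapes match the released
# checkpoint (e.g. cls.0 expecting 4096 input channels), instead of the
# lighter variant the script was constructing.
seg_model = pspnet.PSPNet(layers=50, classes=21, pretrained=False)  # args are assumptions
# Wrapping in DataParallel matches the "module." prefix on the checkpoint keys.
seg_model = torch.nn.DataParallel(seg_model)

checkpoint = torch.load('checkpoints/pspnet_voc2012.pth', map_location='cpu')  # path is hypothetical
seg_model.load_state_dict(checkpoint['state_dict'], strict=False)
```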