Closed fbc-one closed 1 year ago
The model is a simple U-Net and should work with 128x128 inputs if you retrain it.
It seems that there are some problems in your code when loading the retrained parsing network. Possibly because you save the network without DataParallel
while the default load_weight
function expect a model wrapped with DataParallel
.
I am not sure what is the exact problem. Sorry that I am not supposed to help you with the problem, you may use some debug tools like pdb
to fix the bugs by yourself.
Thank you! I have successfully solved the problem. It was indeed caused by the mismatch of the network name saved in the pre-trained model. Currently, I am able to generate mask images with a shape of 128. Sincerely appreciate your reply!
I hope to use your code for 128x128 super-resolution images. In this case, I used the CelebAMask-HQ with a shape of 128x128 for my FPN training. After successfully completing the training, I used the following command to generate masks128 images for a 128x128 FFHQ dataset.
python generate_mask.py --gpu 1 --model parse --src_dir /data1/FFHQ128 \ --Pimg_size 128 --Gin_size 128 --Gout_size 128 --save_masks_dir /data1/maskfirst128 \ --batch_size 8 --parse_net_weight /PSFRGAN/check_points/FPN_0430/latest_net_P_without_module.pth
However, I encountered the error shown in the following image. I would like to ask, is there something wrong with the configuration of my test code? Or is it possible that your model may not be suitable for 128x128 images?
dataset [SingleDataset] was created model [ParseModel] was created Loading pretrained LQ face parsing network from /PSFRGAN/check_points/FPN_0430/latest_net_P_without_module.pth Traceback (most recent call last): File "generate_mask.py", line 21, in <module> model.load_pretrain_models() File "/PSFRGAN/models/parse_model.py", line 54, in load_pretrain_models self.netP.load_state_dict(torch.load(self.opt.parse_net_weight)) File "/anaconda3/envs/PSFR/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1672, in load_state_dict self.__class__.__name__, "\n\t".join(error_msgs))) RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.encoder.0.conv2d.weight", "module.encoder.0.conv2d.bias", "module.encoder.1.shortcut_func.conv2d.weight", "module.encoder.1.shortcut_func.conv2d.bias", "module.encoder.1.conv1.conv2d.weight", "module.encoder.1.conv1.norm.norm.weight", "module.encoder.1.conv1.norm.norm.bias", "module.encoder.1.conv1.norm.norm.running_mean", "module.encoder.1.conv1.norm.norm.running_var", "module.encoder.1.conv2.conv2d.weight", "module.encoder.1.conv2.norm.norm.weight", "module.encoder.1.conv2.norm.norm.bias", "module.encoder.1.conv2.norm.norm.running_mean", "module.encoder.1.conv2.norm.norm.running_var", "module.encoder.2.shortcut_func.conv2d.weight", "module.encoder.2.shortcut_func.conv2d.bias", "module.encoder.2.conv1.conv2d.weight", "module.encoder.2.conv1.norm.norm.weight", "module.encoder.2.conv1.norm.norm.bias", "module.encoder.2.conv1.norm.norm.running_mean", "module.encoder.2.conv1.norm.norm.running_var", "module.encoder.2.conv2.conv2d.weight", "module.encoder.2.conv2.norm.norm.weight", "module.encoder.2.conv2.norm.norm.bias", "module.encoder.2.conv2.norm.norm.running_mean", "module.encoder.2.conv2.norm.norm.running_var", "module.encoder.3.shortcut_func.conv2d.weight", "module.encoder.3.shortcut_func.conv2d.bias", "module.encoder.3.conv1.conv2d.weight", "module.encoder.3.conv1.norm.norm.weight", "module.encoder.3.conv1.norm.norm.bias", "module.encoder.3.conv1.norm.norm.running_mean", "module.encoder.3.conv1.norm.norm.running_var", "module.encoder.3.conv2.conv2d.weight", "module.encoder.3.conv2.norm.norm.weight", "module.encoder.3.conv2.norm.norm.bias", "module.encoder.3.conv2.norm.norm.running_mean", "module.encoder.3.conv2.norm.norm.running_var", "module.encoder.4.shortcut_func.conv2d.weight", "module.encoder.4.shortcut_func.conv2d.bias", "module.encoder.4.conv1.conv2d.weight", "module.encoder.4.conv1.norm.norm.weight", "module.encoder.4.conv1.norm.norm.bias", "module.encoder.4.conv1.norm.norm.running_mean", "module.encoder.4.conv1.norm.norm.running_var", "module.encoder.4.conv2.conv2d.weight", "module.encoder.4.conv2.norm.norm.weight", "module.encoder.4.conv2.norm.norm.bias", "module.encoder.4.conv2.norm.norm.running_mean", "module.encoder.4.conv2.norm.norm.running_var", "module.body.0.conv1.conv2d.weight", "module.body.0.conv1.norm.norm.weight", "module.body.0.conv1.norm.norm.bias", "module.body.0.conv1.norm.norm.running_mean", "module.body.0.conv1.norm.norm.running_var", "module.body.0.conv2.conv2d.weight", "module.body.0.conv2.norm.norm.weight", "module.body.0.conv2.norm.norm.bias", "module.body.0.conv2.norm.norm.running_mean", "module.body.0.conv2.norm.norm.running_var", "module.body.1.conv1.conv2d.weight", "module.body.1.conv1.norm.norm.weight", "module.body.1.conv1.norm.norm.bias", "module.body.1.conv1.norm.norm.running_mean", "module.body.1.conv1.norm.norm.running_var", "module.body.1.conv2.conv2d.weight", "module.body.1.conv2.norm.norm.weight", "module.body.1.conv2.norm.norm.bias", "module.body.1.conv2.norm.norm.running_mean", "module.body.1.conv2.norm.norm.running_var", "module.body.2.conv1.conv2d.weight", "module.body.2.conv1.norm.norm.weight", "module.body.2.conv1.norm.norm.bias", "module.body.2.conv1.norm.norm.running_mean", "module.body.2.conv1.norm.norm.running_var", "module.body.2.conv2.conv2d.weight", "module.body.2.conv2.norm.norm.weight", "module.body.2.conv2.norm.norm.bias", "module.body.2.conv2.norm.norm.running_mean", "module.body.2.conv2.norm.norm.running_var", "module.body.3.conv1.conv2d.weight", "module.body.3.conv1.norm.norm.weight", "module.body.3.conv1.norm.norm.bias", "module.body.3.conv1.norm.norm.running_mean", "module.body.3.conv1.norm.norm.running_var", "module.body.3.conv2.conv2d.weight", "module.body.3.conv2.norm.norm.weight", "module.body.3.conv2.norm.norm.bias", "module.body.3.conv2.norm.norm.running_mean", "module.body.3.conv2.norm.norm.running_var", "module.body.4.conv1.conv2d.weight", "module.body.4.conv1.norm.norm.weight", "module.body.4.conv1.norm.norm.bias", "module.body.4.conv1.norm.norm.running_mean", "module.body.4.conv1.norm.norm.running_var", "module.body.4.conv2.conv2d.weight", "module.body.4.conv2.norm.norm.weight", "module.body.4.conv2.norm.norm.bias", "module.body.4.conv2.norm.norm.running_mean", "module.body.4.conv2.norm.norm.running_var", "module.body.5.conv1.conv2d.weight", "module.body.5.conv1.norm.norm.weight", "module.body.5.conv1.norm.norm.bias", "module.body.5.conv1.norm.norm.running_mean", "module.body.5.conv1.norm.norm.running_var", "module.body.5.conv2.conv2d.weight", "module.body.5.conv2.norm.norm.weight", "module.body.5.conv2.norm.norm.bias", "module.body.5.conv2.norm.norm.running_mean", "module.body.5.conv2.norm.norm.running_var", "module.body.6.conv1.conv2d.weight", "module.body.6.conv1.norm.norm.weight", "module.body.6.conv1.norm.norm.bias", "module.body.6.conv1.norm.norm.running_mean", "module.body.6.conv1.norm.norm.running_var", "module.body.6.conv2.conv2d.weight", "module.body.6.conv2.norm.norm.weight", "module.body.6.conv2.norm.norm.bias", "module.body.6.conv2.norm.norm.running_mean", "module.body.6.conv2.norm.norm.running_var", "module.body.7.conv1.conv2d.weight", "module.body.7.conv1.norm.norm.weight", "module.body.7.conv1.norm.norm.bias", "module.body.7.conv1.norm.norm.running_mean", "module.body.7.conv1.norm.norm.running_var", "module.body.7.conv2.conv2d.weight", "module.body.7.conv2.norm.norm.weight", "module.body.7.conv2.norm.norm.bias", "module.body.7.conv2.norm.norm.running_mean", "module.body.7.conv2.norm.norm.running_var", "module.body.8.conv1.conv2d.weight", "module.body.8.conv1.norm.norm.weight", "module.body.8.conv1.norm.norm.bias", "module.body.8.conv1.norm.norm.running_mean", "module.body.8.conv1.norm.norm.running_var", "module.body.8.conv2.conv2d.weight", "module.body.8.conv2.norm.norm.weight", "module.body.8.conv2.norm.norm.bias", "module.body.8.conv2.norm.norm.running_mean", "module.body.8.conv2.norm.norm.running_var", "module.body.9.conv1.conv2d.weight", "module.body.9.conv1.norm.norm.weight", "module.body.9.conv1.norm.norm.bias", "module.body.9.conv1.norm.norm.running_mean", "module.body.9.conv1.norm.norm.running_var", "module.body.9.conv2.conv2d.weight", "module.body.9.conv2.norm.norm.weight", "module.body.9.conv2.norm.norm.bias", "module.body.9.conv2.norm.norm.running_mean", "module.body.9.conv2.norm.norm.running_var", "module.decoder.0.shortcut_func.conv2d.weight", "module.decoder.0.shortcut_func.conv2d.bias", "module.decoder.0.conv1.conv2d.weight", "module.decoder.0.conv1.norm.norm.weight", "module.decoder.0.conv1.norm.norm.bias", "module.decoder.0.conv1.norm.norm.running_mean", "module.decoder.0.conv1.norm.norm.running_var", "module.decoder.0.conv2.conv2d.weight", "module.decoder.0.conv2.norm.norm.weight", "module.decoder.0.conv2.norm.norm.bias", "module.decoder.0.conv2.norm.norm.running_mean", "module.decoder.0.conv2.norm.norm.running_var", "module.decoder.1.shortcut_func.conv2d.weight", "module.decoder.1.shortcut_func.conv2d.bias", "module.decoder.1.conv1.conv2d.weight", "module.decoder.1.conv1.norm.norm.weight", "module.decoder.1.conv1.norm.norm.bias", "module.decoder.1.conv1.norm.norm.running_mean", "module.decoder.1.conv1.norm.norm.running_var", "module.decoder.1.conv2.conv2d.weight", "module.decoder.1.conv2.norm.norm.weight", "module.decoder.1.conv2.norm.norm.bias", "module.decoder.1.conv2.norm.norm.running_mean", "module.decoder.1.conv2.norm.norm.running_var", "module.decoder.2.shortcut_func.conv2d.weight", "module.decoder.2.shortcut_func.conv2d.bias", "module.decoder.2.conv1.conv2d.weight", "module.decoder.2.conv1.norm.norm.weight", "module.decoder.2.conv1.norm.norm.bias", "module.decoder.2.conv1.norm.norm.running_mean", "module.decoder.2.conv1.norm.norm.running_var", "module.decoder.2.conv2.conv2d.weight", "module.decoder.2.conv2.norm.norm.weight", "module.decoder.2.conv2.norm.norm.bias", "module.decoder.2.conv2.norm.norm.running_mean", "module.decoder.2.conv2.norm.norm.running_var", "module.decoder.3.shortcut_func.conv2d.weight", "module.decoder.3.shortcut_func.conv2d.bias", "module.decoder.3.conv1.conv2d.weight", "module.decoder.3.conv1.norm.norm.weight", "module.decoder.3.conv1.norm.norm.bias", "module.decoder.3.conv1.norm.norm.running_mean", "module.decoder.3.conv1.norm.norm.running_var", "module.decoder.3.conv2.conv2d.weight", "module.decoder.3.conv2.norm.norm.weight", "module.decoder.3.conv2.norm.norm.bias", "module.decoder.3.conv2.norm.norm.running_mean", "module.decoder.3.conv2.norm.norm.running_var", "module.out_img_conv.conv2d.weight", "module.out_img_conv.conv2d.bias", "module.out_mask_conv.conv2d.weight", "module.out_mask_conv.conv2d.bias". Unexpected key(s) in state_dict: "encoder.0.conv2d.weight", "encoder.0.conv2d.bias", "encoder.1.shortcut_func.conv2d.weight", "encoder.1.shortcut_func.conv2d.bias", "encoder.1.conv1.conv2d.weight", "encoder.1.conv1.norm.norm.weight", "encoder.1.conv1.norm.norm.bias", "encoder.1.conv1.norm.norm.running_mean", "encoder.1.conv1.norm.norm.running_var", "encoder.1.conv1.norm.norm.num_batches_tracked", "encoder.1.conv2.conv2d.weight", "encoder.1.conv2.norm.norm.weight", "encoder.1.conv2.norm.norm.bias", "encoder.1.conv2.norm.norm.running_mean", "encoder.1.conv2.norm.norm.running_var", "encoder.1.conv2.norm.norm.num_batches_tracked", "encoder.2.shortcut_func.conv2d.weight", "encoder.2.shortcut_func.conv2d.bias", "encoder.2.conv1.conv2d.weight", "encoder.2.conv1.norm.norm.weight", "encoder.2.conv1.norm.norm.bias", "encoder.2.conv1.norm.norm.running_mean", "encoder.2.conv1.norm.norm.running_var", "encoder.2.conv1.norm.norm.num_batches_tracked", "encoder.2.conv2.conv2d.weight", "encoder.2.conv2.norm.norm.weight", "encoder.2.conv2.norm.norm.bias", "encoder.2.conv2.norm.norm.running_mean", "encoder.2.conv2.norm.norm.running_var", "encoder.2.conv2.norm.norm.num_batches_tracked", "encoder.3.shortcut_func.conv2d.weight", "encoder.3.shortcut_func.conv2d.bias", "encoder.3.conv1.conv2d.weight", "encoder.3.conv1.norm.norm.weight", "encoder.3.conv1.norm.norm.bias", "encoder.3.conv1.norm.norm.running_mean", "encoder.3.conv1.norm.norm.running_var", "encoder.3.conv1.norm.norm.num_batches_tracked", "encoder.3.conv2.conv2d.weight", "encoder.3.conv2.norm.norm.weight", "encoder.3.conv2.norm.norm.bias", "encoder.3.conv2.norm.norm.running_mean", "encoder.3.conv2.norm.norm.running_var", "encoder.3.conv2.norm.norm.num_batches_tracked", "encoder.4.shortcut_func.conv2d.weight", "encoder.4.shortcut_func.conv2d.bias", "encoder.4.conv1.conv2d.weight", "encoder.4.conv1.norm.norm.weight", "encoder.4.conv1.norm.norm.bias", "encoder.4.conv1.norm.norm.running_mean", "encoder.4.conv1.norm.norm.running_var", "encoder.4.conv1.norm.norm.num_batches_tracked", "encoder.4.conv2.conv2d.weight", "encoder.4.conv2.norm.norm.weight", "encoder.4.conv2.norm.norm.bias", "encoder.4.conv2.norm.norm.running_mean", "encoder.4.conv2.norm.norm.running_var", "encoder.4.conv2.norm.norm.num_batches_tracked", "body.0.conv1.conv2d.weight", "body.0.conv1.norm.norm.weight", "body.0.conv1.norm.norm.bias", "body.0.conv1.norm.norm.running_mean", "body.0.conv1.norm.norm.running_var", "body.0.conv1.norm.norm.num_batches_tracked", "body.0.conv2.conv2d.weight", "body.0.conv2.norm.norm.weight", "body.0.conv2.norm.norm.bias", "body.0.conv2.norm.norm.running_mean", "body.0.conv2.norm.norm.running_var", "body.0.conv2.norm.norm.num_batches_tracked", "body.1.conv1.conv2d.weight", "body.1.conv1.norm.norm.weight", "body.1.conv1.norm.norm.bias", "body.1.conv1.norm.norm.running_mean", "body.1.conv1.norm.norm.running_var", "body.1.conv1.norm.norm.num_batches_tracked", "body.1.conv2.conv2d.weight", "body.1.conv2.norm.norm.weight", "body.1.conv2.norm.norm.bias", "body.1.conv2.norm.norm.running_mean", "body.1.conv2.norm.norm.running_var", "body.1.conv2.norm.norm.num_batches_tracked", "body.2.conv1.conv2d.weight", "body.2.conv1.norm.norm.weight", "body.2.conv1.norm.norm.bias", "body.2.conv1.norm.norm.running_mean", "body.2.conv1.norm.norm.running_var", "body.2.conv1.norm.norm.num_batches_tracked", "body.2.conv2.conv2d.weight", "body.2.conv2.norm.norm.weight", "body.2.conv2.norm.norm.bias", "body.2.conv2.norm.norm.running_mean", "body.2.conv2.norm.norm.running_var", "body.2.conv2.norm.norm.num_batches_tracked", "body.3.conv1.conv2d.weight", "body.3.conv1.norm.norm.weight", "body.3.conv1.norm.norm.bias", "body.3.conv1.norm.norm.running_mean", "body.3.conv1.norm.norm.running_var", "body.3.conv1.norm.norm.num_batches_tracked", "body.3.conv2.conv2d.weight", "body.3.conv2.norm.norm.weight", "body.3.conv2.norm.norm.bias", "body.3.conv2.norm.norm.running_mean", "body.3.conv2.norm.norm.running_var", "body.3.conv2.norm.norm.num_batches_tracked", "body.4.conv1.conv2d.weight", "body.4.conv1.norm.norm.weight", "body.4.conv1.norm.norm.bias", "body.4.conv1.norm.norm.running_mean", "body.4.conv1.norm.norm.running_var", "body.4.conv1.norm.norm.num_batches_tracked", "body.4.conv2.conv2d.weight", "body.4.conv2.norm.norm.weight", "body.4.conv2.norm.norm.bias", "body.4.conv2.norm.norm.running_mean", "body.4.conv2.norm.norm.running_var", "body.4.conv2.norm.norm.num_batches_tracked", "body.5.conv1.conv2d.weight", "body.5.conv1.norm.norm.weight", "body.5.conv1.norm.norm.bias", "body.5.conv1.norm.norm.running_mean", "body.5.conv1.norm.norm.running_var", "body.5.conv1.norm.norm.num_batches_tracked", "body.5.conv2.conv2d.weight", "body.5.conv2.norm.norm.weight", "body.5.conv2.norm.norm.bias", "body.5.conv2.norm.norm.running_mean", "body.5.conv2.norm.norm.running_var", "body.5.conv2.norm.norm.num_batches_tracked", "body.6.conv1.conv2d.weight", "body.6.conv1.norm.norm.weight", "body.6.conv1.norm.norm.bias", "body.6.conv1.norm.norm.running_mean", "body.6.conv1.norm.norm.running_var", "body.6.conv1.norm.norm.num_batches_tracked", "body.6.conv2.conv2d.weight", "body.6.conv2.norm.norm.weight", "body.6.conv2.norm.norm.bias", "body.6.conv2.norm.norm.running_mean", "body.6.conv2.norm.norm.running_var", "body.6.conv2.norm.norm.num_batches_tracked", "body.7.conv1.conv2d.weight", "body.7.conv1.norm.norm.weight", "body.7.conv1.norm.norm.bias", "body.7.conv1.norm.norm.running_mean", "body.7.conv1.norm.norm.running_var", "body.7.conv1.norm.norm.num_batches_tracked", "body.7.conv2.conv2d.weight", "body.7.conv2.norm.norm.weight", "body.7.conv2.norm.norm.bias", "body.7.conv2.norm.norm.running_mean", "body.7.conv2.norm.norm.running_var", "body.7.conv2.norm.norm.num_batches_tracked", "body.8.conv1.conv2d.weight", "body.8.conv1.norm.norm.weight", "body.8.conv1.norm.norm.bias", "body.8.conv1.norm.norm.running_mean", "body.8.conv1.norm.norm.running_var", "body.8.conv1.norm.norm.num_batches_tracked", "body.8.conv2.conv2d.weight", "body.8.conv2.norm.norm.weight", "body.8.conv2.norm.norm.bias", "body.8.conv2.norm.norm.running_mean", "body.8.conv2.norm.norm.running_var", "body.8.conv2.norm.norm.num_batches_tracked", "body.9.conv1.conv2d.weight", "body.9.conv1.norm.norm.weight", "body.9.conv1.norm.norm.bias", "body.9.conv1.norm.norm.running_mean", "body.9.conv1.norm.norm.running_var", "body.9.conv1.norm.norm.num_batches_tracked", "body.9.conv2.conv2d.weight", "body.9.conv2.norm.norm.weight", "body.9.conv2.norm.norm.bias", "body.9.conv2.norm.norm.running_mean", "body.9.conv2.norm.norm.running_var", "body.9.conv2.norm.norm.num_batches_tracked", "decoder.0.shortcut_func.conv2d.weight", "decoder.0.shortcut_func.conv2d.bias", "decoder.0.conv1.conv2d.weight", "decoder.0.conv1.norm.norm.weight", "decoder.0.conv1.norm.norm.bias", "decoder.0.conv1.norm.norm.running_mean", "decoder.0.conv1.norm.norm.running_var", "decoder.0.conv1.norm.norm.num_batches_tracked", "decoder.0.conv2.conv2d.weight", "decoder.0.conv2.norm.norm.weight", "decoder.0.conv2.norm.norm.bias", "decoder.0.conv2.norm.norm.running_mean", "decoder.0.conv2.norm.norm.running_var", "decoder.0.conv2.norm.norm.num_batches_tracked", "decoder.1.shortcut_func.conv2d.weight", "decoder.1.shortcut_func.conv2d.bias", "decoder.1.conv1.conv2d.weight", "decoder.1.conv1.norm.norm.weight", "decoder.1.conv1.norm.norm.bias", "decoder.1.conv1.norm.norm.running_mean", "decoder.1.conv1.norm.norm.running_var", "decoder.1.conv1.norm.norm.num_batches_tracked", "decoder.1.conv2.conv2d.weight", "decoder.1.conv2.norm.norm.weight", "decoder.1.conv2.norm.norm.bias", "decoder.1.conv2.norm.norm.running_mean", "decoder.1.conv2.norm.norm.running_var", "decoder.1.conv2.norm.norm.num_batches_tracked", "decoder.2.shortcut_func.conv2d.weight", "decoder.2.shortcut_func.conv2d.bias", "decoder.2.conv1.conv2d.weight", "decoder.2.conv1.norm.norm.weight", "decoder.2.conv1.norm.norm.bias", "decoder.2.conv1.norm.norm.running_mean", "decoder.2.conv1.norm.norm.running_var", "decoder.2.conv1.norm.norm.num_batches_tracked", "decoder.2.conv2.conv2d.weight", "decoder.2.conv2.norm.norm.weight", "decoder.2.conv2.norm.norm.bias", "decoder.2.conv2.norm.norm.running_mean", "decoder.2.conv2.norm.norm.running_var", "decoder.2.conv2.norm.norm.num_batches_tracked", "decoder.3.shortcut_func.conv2d.weight", "decoder.3.shortcut_func.conv2d.bias", "decoder.3.conv1.conv2d.weight", "decoder.3.conv1.norm.norm.weight", "decoder.3.conv1.norm.norm.bias", "decoder.3.conv1.norm.norm.running_mean", "decoder.3.conv1.norm.norm.running_var", "decoder.3.conv1.norm.norm.num_batches_tracked", "decoder.3.conv2.conv2d.weight", "decoder.3.conv2.norm.norm.weight", "decoder.3.conv2.norm.norm.bias", "decoder.3.conv2.norm.norm.running_mean", "decoder.3.conv2.norm.norm.running_var", "decoder.3.conv2.norm.norm.num_batches_tracked", "out_img_conv.conv2d.weight", "out_img_conv.conv2d.bias", "out_mask_conv.conv2d.weight", "out_mask_conv.conv2d.bias".
By the way, when I run the generate_mask.py code directly using the FPN I trained as mentioned above, the generated image looks like this.
Looking forward to your reply!