chaofengc / PSFRGAN

PyTorch codes for "Progressive Semantic-Aware Style Transformation for Blind Face Restoration", CVPR2021
Other
370 stars 68 forks source link

Some trouble when using 128x128 shape of image. #58

Closed fbc-one closed 1 year ago

fbc-one commented 1 year ago

I hope to use your code for 128x128 super-resolution images. In this case, I used the CelebAMask-HQ with a shape of 128x128 for my FPN training. After successfully completing the training, I used the following command to generate masks128 images for a 128x128 FFHQ dataset. python generate_mask.py --gpu 1 --model parse --src_dir /data1/FFHQ128 \ --Pimg_size 128 --Gin_size 128 --Gout_size 128 --save_masks_dir /data1/maskfirst128 \ --batch_size 8 --parse_net_weight /PSFRGAN/check_points/FPN_0430/latest_net_P_without_module.pth

However, I encountered the error shown in the following image. I would like to ask, is there something wrong with the configuration of my test code? Or is it possible that your model may not be suitable for 128x128 images? dataset [SingleDataset] was created model [ParseModel] was created Loading pretrained LQ face parsing network from /PSFRGAN/check_points/FPN_0430/latest_net_P_without_module.pth Traceback (most recent call last): File "generate_mask.py", line 21, in <module> model.load_pretrain_models() File "/PSFRGAN/models/parse_model.py", line 54, in load_pretrain_models self.netP.load_state_dict(torch.load(self.opt.parse_net_weight)) File "/anaconda3/envs/PSFR/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1672, in load_state_dict self.__class__.__name__, "\n\t".join(error_msgs))) RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.encoder.0.conv2d.weight", "module.encoder.0.conv2d.bias", "module.encoder.1.shortcut_func.conv2d.weight", "module.encoder.1.shortcut_func.conv2d.bias", "module.encoder.1.conv1.conv2d.weight", "module.encoder.1.conv1.norm.norm.weight", "module.encoder.1.conv1.norm.norm.bias", "module.encoder.1.conv1.norm.norm.running_mean", "module.encoder.1.conv1.norm.norm.running_var", "module.encoder.1.conv2.conv2d.weight", "module.encoder.1.conv2.norm.norm.weight", "module.encoder.1.conv2.norm.norm.bias", "module.encoder.1.conv2.norm.norm.running_mean", "module.encoder.1.conv2.norm.norm.running_var", "module.encoder.2.shortcut_func.conv2d.weight", "module.encoder.2.shortcut_func.conv2d.bias", "module.encoder.2.conv1.conv2d.weight", "module.encoder.2.conv1.norm.norm.weight", "module.encoder.2.conv1.norm.norm.bias", "module.encoder.2.conv1.norm.norm.running_mean", "module.encoder.2.conv1.norm.norm.running_var", "module.encoder.2.conv2.conv2d.weight", "module.encoder.2.conv2.norm.norm.weight", "module.encoder.2.conv2.norm.norm.bias", "module.encoder.2.conv2.norm.norm.running_mean", "module.encoder.2.conv2.norm.norm.running_var", "module.encoder.3.shortcut_func.conv2d.weight", "module.encoder.3.shortcut_func.conv2d.bias", "module.encoder.3.conv1.conv2d.weight", "module.encoder.3.conv1.norm.norm.weight", "module.encoder.3.conv1.norm.norm.bias", "module.encoder.3.conv1.norm.norm.running_mean", "module.encoder.3.conv1.norm.norm.running_var", "module.encoder.3.conv2.conv2d.weight", "module.encoder.3.conv2.norm.norm.weight", "module.encoder.3.conv2.norm.norm.bias", "module.encoder.3.conv2.norm.norm.running_mean", "module.encoder.3.conv2.norm.norm.running_var", "module.encoder.4.shortcut_func.conv2d.weight", "module.encoder.4.shortcut_func.conv2d.bias", "module.encoder.4.conv1.conv2d.weight", "module.encoder.4.conv1.norm.norm.weight", "module.encoder.4.conv1.norm.norm.bias", "module.encoder.4.conv1.norm.norm.running_mean", "module.encoder.4.conv1.norm.norm.running_var", "module.encoder.4.conv2.conv2d.weight", "module.encoder.4.conv2.norm.norm.weight", "module.encoder.4.conv2.norm.norm.bias", "module.encoder.4.conv2.norm.norm.running_mean", "module.encoder.4.conv2.norm.norm.running_var", "module.body.0.conv1.conv2d.weight", "module.body.0.conv1.norm.norm.weight", "module.body.0.conv1.norm.norm.bias", "module.body.0.conv1.norm.norm.running_mean", "module.body.0.conv1.norm.norm.running_var", "module.body.0.conv2.conv2d.weight", "module.body.0.conv2.norm.norm.weight", "module.body.0.conv2.norm.norm.bias", "module.body.0.conv2.norm.norm.running_mean", "module.body.0.conv2.norm.norm.running_var", "module.body.1.conv1.conv2d.weight", "module.body.1.conv1.norm.norm.weight", "module.body.1.conv1.norm.norm.bias", "module.body.1.conv1.norm.norm.running_mean", "module.body.1.conv1.norm.norm.running_var", "module.body.1.conv2.conv2d.weight", "module.body.1.conv2.norm.norm.weight", "module.body.1.conv2.norm.norm.bias", "module.body.1.conv2.norm.norm.running_mean", "module.body.1.conv2.norm.norm.running_var", "module.body.2.conv1.conv2d.weight", "module.body.2.conv1.norm.norm.weight", "module.body.2.conv1.norm.norm.bias", "module.body.2.conv1.norm.norm.running_mean", "module.body.2.conv1.norm.norm.running_var", "module.body.2.conv2.conv2d.weight", "module.body.2.conv2.norm.norm.weight", "module.body.2.conv2.norm.norm.bias", "module.body.2.conv2.norm.norm.running_mean", "module.body.2.conv2.norm.norm.running_var", "module.body.3.conv1.conv2d.weight", "module.body.3.conv1.norm.norm.weight", "module.body.3.conv1.norm.norm.bias", "module.body.3.conv1.norm.norm.running_mean", "module.body.3.conv1.norm.norm.running_var", "module.body.3.conv2.conv2d.weight", "module.body.3.conv2.norm.norm.weight", "module.body.3.conv2.norm.norm.bias", "module.body.3.conv2.norm.norm.running_mean", "module.body.3.conv2.norm.norm.running_var", "module.body.4.conv1.conv2d.weight", "module.body.4.conv1.norm.norm.weight", "module.body.4.conv1.norm.norm.bias", "module.body.4.conv1.norm.norm.running_mean", "module.body.4.conv1.norm.norm.running_var", "module.body.4.conv2.conv2d.weight", "module.body.4.conv2.norm.norm.weight", "module.body.4.conv2.norm.norm.bias", "module.body.4.conv2.norm.norm.running_mean", "module.body.4.conv2.norm.norm.running_var", "module.body.5.conv1.conv2d.weight", "module.body.5.conv1.norm.norm.weight", "module.body.5.conv1.norm.norm.bias", "module.body.5.conv1.norm.norm.running_mean", "module.body.5.conv1.norm.norm.running_var", "module.body.5.conv2.conv2d.weight", "module.body.5.conv2.norm.norm.weight", "module.body.5.conv2.norm.norm.bias", "module.body.5.conv2.norm.norm.running_mean", "module.body.5.conv2.norm.norm.running_var", "module.body.6.conv1.conv2d.weight", "module.body.6.conv1.norm.norm.weight", "module.body.6.conv1.norm.norm.bias", "module.body.6.conv1.norm.norm.running_mean", "module.body.6.conv1.norm.norm.running_var", "module.body.6.conv2.conv2d.weight", "module.body.6.conv2.norm.norm.weight", "module.body.6.conv2.norm.norm.bias", "module.body.6.conv2.norm.norm.running_mean", "module.body.6.conv2.norm.norm.running_var", "module.body.7.conv1.conv2d.weight", "module.body.7.conv1.norm.norm.weight", "module.body.7.conv1.norm.norm.bias", "module.body.7.conv1.norm.norm.running_mean", "module.body.7.conv1.norm.norm.running_var", "module.body.7.conv2.conv2d.weight", "module.body.7.conv2.norm.norm.weight", "module.body.7.conv2.norm.norm.bias", "module.body.7.conv2.norm.norm.running_mean", "module.body.7.conv2.norm.norm.running_var", "module.body.8.conv1.conv2d.weight", "module.body.8.conv1.norm.norm.weight", "module.body.8.conv1.norm.norm.bias", "module.body.8.conv1.norm.norm.running_mean", "module.body.8.conv1.norm.norm.running_var", "module.body.8.conv2.conv2d.weight", "module.body.8.conv2.norm.norm.weight", "module.body.8.conv2.norm.norm.bias", "module.body.8.conv2.norm.norm.running_mean", "module.body.8.conv2.norm.norm.running_var", "module.body.9.conv1.conv2d.weight", "module.body.9.conv1.norm.norm.weight", "module.body.9.conv1.norm.norm.bias", "module.body.9.conv1.norm.norm.running_mean", "module.body.9.conv1.norm.norm.running_var", "module.body.9.conv2.conv2d.weight", "module.body.9.conv2.norm.norm.weight", "module.body.9.conv2.norm.norm.bias", "module.body.9.conv2.norm.norm.running_mean", "module.body.9.conv2.norm.norm.running_var", "module.decoder.0.shortcut_func.conv2d.weight", "module.decoder.0.shortcut_func.conv2d.bias", "module.decoder.0.conv1.conv2d.weight", "module.decoder.0.conv1.norm.norm.weight", "module.decoder.0.conv1.norm.norm.bias", "module.decoder.0.conv1.norm.norm.running_mean", "module.decoder.0.conv1.norm.norm.running_var", "module.decoder.0.conv2.conv2d.weight", "module.decoder.0.conv2.norm.norm.weight", "module.decoder.0.conv2.norm.norm.bias", "module.decoder.0.conv2.norm.norm.running_mean", "module.decoder.0.conv2.norm.norm.running_var", "module.decoder.1.shortcut_func.conv2d.weight", "module.decoder.1.shortcut_func.conv2d.bias", "module.decoder.1.conv1.conv2d.weight", "module.decoder.1.conv1.norm.norm.weight", "module.decoder.1.conv1.norm.norm.bias", "module.decoder.1.conv1.norm.norm.running_mean", "module.decoder.1.conv1.norm.norm.running_var", "module.decoder.1.conv2.conv2d.weight", "module.decoder.1.conv2.norm.norm.weight", "module.decoder.1.conv2.norm.norm.bias", "module.decoder.1.conv2.norm.norm.running_mean", "module.decoder.1.conv2.norm.norm.running_var", "module.decoder.2.shortcut_func.conv2d.weight", "module.decoder.2.shortcut_func.conv2d.bias", "module.decoder.2.conv1.conv2d.weight", "module.decoder.2.conv1.norm.norm.weight", "module.decoder.2.conv1.norm.norm.bias", "module.decoder.2.conv1.norm.norm.running_mean", "module.decoder.2.conv1.norm.norm.running_var", "module.decoder.2.conv2.conv2d.weight", "module.decoder.2.conv2.norm.norm.weight", "module.decoder.2.conv2.norm.norm.bias", "module.decoder.2.conv2.norm.norm.running_mean", "module.decoder.2.conv2.norm.norm.running_var", "module.decoder.3.shortcut_func.conv2d.weight", "module.decoder.3.shortcut_func.conv2d.bias", "module.decoder.3.conv1.conv2d.weight", "module.decoder.3.conv1.norm.norm.weight", "module.decoder.3.conv1.norm.norm.bias", "module.decoder.3.conv1.norm.norm.running_mean", "module.decoder.3.conv1.norm.norm.running_var", "module.decoder.3.conv2.conv2d.weight", "module.decoder.3.conv2.norm.norm.weight", "module.decoder.3.conv2.norm.norm.bias", "module.decoder.3.conv2.norm.norm.running_mean", "module.decoder.3.conv2.norm.norm.running_var", "module.out_img_conv.conv2d.weight", "module.out_img_conv.conv2d.bias", "module.out_mask_conv.conv2d.weight", "module.out_mask_conv.conv2d.bias". Unexpected key(s) in state_dict: "encoder.0.conv2d.weight", "encoder.0.conv2d.bias", "encoder.1.shortcut_func.conv2d.weight", "encoder.1.shortcut_func.conv2d.bias", "encoder.1.conv1.conv2d.weight", "encoder.1.conv1.norm.norm.weight", "encoder.1.conv1.norm.norm.bias", "encoder.1.conv1.norm.norm.running_mean", "encoder.1.conv1.norm.norm.running_var", "encoder.1.conv1.norm.norm.num_batches_tracked", "encoder.1.conv2.conv2d.weight", "encoder.1.conv2.norm.norm.weight", "encoder.1.conv2.norm.norm.bias", "encoder.1.conv2.norm.norm.running_mean", "encoder.1.conv2.norm.norm.running_var", "encoder.1.conv2.norm.norm.num_batches_tracked", "encoder.2.shortcut_func.conv2d.weight", "encoder.2.shortcut_func.conv2d.bias", "encoder.2.conv1.conv2d.weight", "encoder.2.conv1.norm.norm.weight", "encoder.2.conv1.norm.norm.bias", "encoder.2.conv1.norm.norm.running_mean", "encoder.2.conv1.norm.norm.running_var", "encoder.2.conv1.norm.norm.num_batches_tracked", "encoder.2.conv2.conv2d.weight", "encoder.2.conv2.norm.norm.weight", "encoder.2.conv2.norm.norm.bias", "encoder.2.conv2.norm.norm.running_mean", "encoder.2.conv2.norm.norm.running_var", "encoder.2.conv2.norm.norm.num_batches_tracked", "encoder.3.shortcut_func.conv2d.weight", "encoder.3.shortcut_func.conv2d.bias", "encoder.3.conv1.conv2d.weight", "encoder.3.conv1.norm.norm.weight", "encoder.3.conv1.norm.norm.bias", "encoder.3.conv1.norm.norm.running_mean", "encoder.3.conv1.norm.norm.running_var", "encoder.3.conv1.norm.norm.num_batches_tracked", "encoder.3.conv2.conv2d.weight", "encoder.3.conv2.norm.norm.weight", "encoder.3.conv2.norm.norm.bias", "encoder.3.conv2.norm.norm.running_mean", "encoder.3.conv2.norm.norm.running_var", "encoder.3.conv2.norm.norm.num_batches_tracked", "encoder.4.shortcut_func.conv2d.weight", "encoder.4.shortcut_func.conv2d.bias", "encoder.4.conv1.conv2d.weight", "encoder.4.conv1.norm.norm.weight", "encoder.4.conv1.norm.norm.bias", "encoder.4.conv1.norm.norm.running_mean", "encoder.4.conv1.norm.norm.running_var", "encoder.4.conv1.norm.norm.num_batches_tracked", "encoder.4.conv2.conv2d.weight", "encoder.4.conv2.norm.norm.weight", "encoder.4.conv2.norm.norm.bias", "encoder.4.conv2.norm.norm.running_mean", "encoder.4.conv2.norm.norm.running_var", "encoder.4.conv2.norm.norm.num_batches_tracked", "body.0.conv1.conv2d.weight", "body.0.conv1.norm.norm.weight", "body.0.conv1.norm.norm.bias", "body.0.conv1.norm.norm.running_mean", "body.0.conv1.norm.norm.running_var", "body.0.conv1.norm.norm.num_batches_tracked", "body.0.conv2.conv2d.weight", "body.0.conv2.norm.norm.weight", "body.0.conv2.norm.norm.bias", "body.0.conv2.norm.norm.running_mean", "body.0.conv2.norm.norm.running_var", "body.0.conv2.norm.norm.num_batches_tracked", "body.1.conv1.conv2d.weight", "body.1.conv1.norm.norm.weight", "body.1.conv1.norm.norm.bias", "body.1.conv1.norm.norm.running_mean", "body.1.conv1.norm.norm.running_var", "body.1.conv1.norm.norm.num_batches_tracked", "body.1.conv2.conv2d.weight", "body.1.conv2.norm.norm.weight", "body.1.conv2.norm.norm.bias", "body.1.conv2.norm.norm.running_mean", "body.1.conv2.norm.norm.running_var", "body.1.conv2.norm.norm.num_batches_tracked", "body.2.conv1.conv2d.weight", "body.2.conv1.norm.norm.weight", "body.2.conv1.norm.norm.bias", "body.2.conv1.norm.norm.running_mean", "body.2.conv1.norm.norm.running_var", "body.2.conv1.norm.norm.num_batches_tracked", "body.2.conv2.conv2d.weight", "body.2.conv2.norm.norm.weight", "body.2.conv2.norm.norm.bias", "body.2.conv2.norm.norm.running_mean", "body.2.conv2.norm.norm.running_var", "body.2.conv2.norm.norm.num_batches_tracked", "body.3.conv1.conv2d.weight", "body.3.conv1.norm.norm.weight", "body.3.conv1.norm.norm.bias", "body.3.conv1.norm.norm.running_mean", "body.3.conv1.norm.norm.running_var", "body.3.conv1.norm.norm.num_batches_tracked", "body.3.conv2.conv2d.weight", "body.3.conv2.norm.norm.weight", "body.3.conv2.norm.norm.bias", "body.3.conv2.norm.norm.running_mean", "body.3.conv2.norm.norm.running_var", "body.3.conv2.norm.norm.num_batches_tracked", "body.4.conv1.conv2d.weight", "body.4.conv1.norm.norm.weight", "body.4.conv1.norm.norm.bias", "body.4.conv1.norm.norm.running_mean", "body.4.conv1.norm.norm.running_var", "body.4.conv1.norm.norm.num_batches_tracked", "body.4.conv2.conv2d.weight", "body.4.conv2.norm.norm.weight", "body.4.conv2.norm.norm.bias", "body.4.conv2.norm.norm.running_mean", "body.4.conv2.norm.norm.running_var", "body.4.conv2.norm.norm.num_batches_tracked", "body.5.conv1.conv2d.weight", "body.5.conv1.norm.norm.weight", "body.5.conv1.norm.norm.bias", "body.5.conv1.norm.norm.running_mean", "body.5.conv1.norm.norm.running_var", "body.5.conv1.norm.norm.num_batches_tracked", "body.5.conv2.conv2d.weight", "body.5.conv2.norm.norm.weight", "body.5.conv2.norm.norm.bias", "body.5.conv2.norm.norm.running_mean", "body.5.conv2.norm.norm.running_var", "body.5.conv2.norm.norm.num_batches_tracked", "body.6.conv1.conv2d.weight", "body.6.conv1.norm.norm.weight", "body.6.conv1.norm.norm.bias", "body.6.conv1.norm.norm.running_mean", "body.6.conv1.norm.norm.running_var", "body.6.conv1.norm.norm.num_batches_tracked", "body.6.conv2.conv2d.weight", "body.6.conv2.norm.norm.weight", "body.6.conv2.norm.norm.bias", "body.6.conv2.norm.norm.running_mean", "body.6.conv2.norm.norm.running_var", "body.6.conv2.norm.norm.num_batches_tracked", "body.7.conv1.conv2d.weight", "body.7.conv1.norm.norm.weight", "body.7.conv1.norm.norm.bias", "body.7.conv1.norm.norm.running_mean", "body.7.conv1.norm.norm.running_var", "body.7.conv1.norm.norm.num_batches_tracked", "body.7.conv2.conv2d.weight", "body.7.conv2.norm.norm.weight", "body.7.conv2.norm.norm.bias", "body.7.conv2.norm.norm.running_mean", "body.7.conv2.norm.norm.running_var", "body.7.conv2.norm.norm.num_batches_tracked", "body.8.conv1.conv2d.weight", "body.8.conv1.norm.norm.weight", "body.8.conv1.norm.norm.bias", "body.8.conv1.norm.norm.running_mean", "body.8.conv1.norm.norm.running_var", "body.8.conv1.norm.norm.num_batches_tracked", "body.8.conv2.conv2d.weight", "body.8.conv2.norm.norm.weight", "body.8.conv2.norm.norm.bias", "body.8.conv2.norm.norm.running_mean", "body.8.conv2.norm.norm.running_var", "body.8.conv2.norm.norm.num_batches_tracked", "body.9.conv1.conv2d.weight", "body.9.conv1.norm.norm.weight", "body.9.conv1.norm.norm.bias", "body.9.conv1.norm.norm.running_mean", "body.9.conv1.norm.norm.running_var", "body.9.conv1.norm.norm.num_batches_tracked", "body.9.conv2.conv2d.weight", "body.9.conv2.norm.norm.weight", "body.9.conv2.norm.norm.bias", "body.9.conv2.norm.norm.running_mean", "body.9.conv2.norm.norm.running_var", "body.9.conv2.norm.norm.num_batches_tracked", "decoder.0.shortcut_func.conv2d.weight", "decoder.0.shortcut_func.conv2d.bias", "decoder.0.conv1.conv2d.weight", "decoder.0.conv1.norm.norm.weight", "decoder.0.conv1.norm.norm.bias", "decoder.0.conv1.norm.norm.running_mean", "decoder.0.conv1.norm.norm.running_var", "decoder.0.conv1.norm.norm.num_batches_tracked", "decoder.0.conv2.conv2d.weight", "decoder.0.conv2.norm.norm.weight", "decoder.0.conv2.norm.norm.bias", "decoder.0.conv2.norm.norm.running_mean", "decoder.0.conv2.norm.norm.running_var", "decoder.0.conv2.norm.norm.num_batches_tracked", "decoder.1.shortcut_func.conv2d.weight", "decoder.1.shortcut_func.conv2d.bias", "decoder.1.conv1.conv2d.weight", "decoder.1.conv1.norm.norm.weight", "decoder.1.conv1.norm.norm.bias", "decoder.1.conv1.norm.norm.running_mean", "decoder.1.conv1.norm.norm.running_var", "decoder.1.conv1.norm.norm.num_batches_tracked", "decoder.1.conv2.conv2d.weight", "decoder.1.conv2.norm.norm.weight", "decoder.1.conv2.norm.norm.bias", "decoder.1.conv2.norm.norm.running_mean", "decoder.1.conv2.norm.norm.running_var", "decoder.1.conv2.norm.norm.num_batches_tracked", "decoder.2.shortcut_func.conv2d.weight", "decoder.2.shortcut_func.conv2d.bias", "decoder.2.conv1.conv2d.weight", "decoder.2.conv1.norm.norm.weight", "decoder.2.conv1.norm.norm.bias", "decoder.2.conv1.norm.norm.running_mean", "decoder.2.conv1.norm.norm.running_var", "decoder.2.conv1.norm.norm.num_batches_tracked", "decoder.2.conv2.conv2d.weight", "decoder.2.conv2.norm.norm.weight", "decoder.2.conv2.norm.norm.bias", "decoder.2.conv2.norm.norm.running_mean", "decoder.2.conv2.norm.norm.running_var", "decoder.2.conv2.norm.norm.num_batches_tracked", "decoder.3.shortcut_func.conv2d.weight", "decoder.3.shortcut_func.conv2d.bias", "decoder.3.conv1.conv2d.weight", "decoder.3.conv1.norm.norm.weight", "decoder.3.conv1.norm.norm.bias", "decoder.3.conv1.norm.norm.running_mean", "decoder.3.conv1.norm.norm.running_var", "decoder.3.conv1.norm.norm.num_batches_tracked", "decoder.3.conv2.conv2d.weight", "decoder.3.conv2.norm.norm.weight", "decoder.3.conv2.norm.norm.bias", "decoder.3.conv2.norm.norm.running_mean", "decoder.3.conv2.norm.norm.running_var", "decoder.3.conv2.norm.norm.num_batches_tracked", "out_img_conv.conv2d.weight", "out_img_conv.conv2d.bias", "out_mask_conv.conv2d.weight", "out_mask_conv.conv2d.bias".

By the way, when I run the generate_mask.py code directly using the FPN I trained as mentioned above, the generated image looks like this. image

Looking forward to your reply!

chaofengc commented 1 year ago

The model is a simple U-Net and should work with 128x128 inputs if you retrain it.

It seems that there are some problems in your code when loading the retrained parsing network. Possibly because you save the network without DataParallel while the default load_weight function expect a model wrapped with DataParallel.

I am not sure what is the exact problem. Sorry that I am not supposed to help you with the problem, you may use some debug tools like pdb to fix the bugs by yourself.

fbc-one commented 1 year ago

Thank you! I have successfully solved the problem. It was indeed caused by the mismatch of the network name saved in the pre-trained model. Currently, I am able to generate mask images with a shape of 128. Sincerely appreciate your reply!