syncdoth / face_mask_inpaint

A GAN-based project for removing face masks from face images.

face_mask_inpaint

This is a repository for "Reference Guided Facial Mask Removal" by Sehyun Choi and Minseok Oh, for the final project of HKUST COMP 4471, 2021 Fall.

Environment

We recommend using a conda environment. First, create one with:

conda create -n $env_name python=$py_ver
conda activate $env_name

Then, set up the conda environment with the script we provide at env_setup.

Experiments

First, download the CelebA and CelebAHQ datasets from their official project pages.

We expect the dataset folder to be structured as:

CelebAHQ
├──images  # all source images
├──images_masked  # masked with MaskTheFace
├──images_masked_test  # pre selected test set
├──binary_map  # mask map .npy files
├──identity.txt
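Before training or evaluating, it can help to confirm that the dataset root matches this layout. The helper below is a minimal, hypothetical sketch (missing_entries is not part of the repo). It uses only pathlib to check for the directories and the identity file listed above. The PICNet inference command further down passes CelebA-HQ-identity.txt as the identity file, so adjust EXPECTED_FILES to whatever name your copy uses.

```python
from __future__ import annotations

from pathlib import Path

# Entries from the expected CelebAHQ layout above (file name is an assumption; adjust as needed).
EXPECTED_DIRS = ["images", "images_masked", "images_masked_test", "binary_map"]
EXPECTED_FILES = ["identity.txt"]


def missing_entries(data_root: str) -> list[str]:
    """Return the expected entries absent under data_root (empty list if the layout is complete)."""
    root = Path(data_root)
    missing = [d for d in EXPECTED_DIRS if not (root / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (root / f).is_file()]
    return missing
```

For example, missing_entries("/path/to/CelebAHQ") returns an empty list when every directory and the identity file are present.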

After downloading the datasets, download the pretrained PICNet and pSp models by following the instructions in their original repositories. (For pSp, we used the StyleGAN inversion checkpoint.)

Then, refer to the scripts for the various training configurations.

Evaluation

For evaluation, we use SSIM, MS-SSIM, and FID. To obtain SSIM and MS-SSIM, run psp_inference.py or PICNet_inference.py, depending on the model you want to test. Example commands for each:

python psp_inference.py --use_ref --use_attention 1 \
--pt_ckpt_path saved_model/RefpSp_train_decoder_attention/G_checkpoint_epoch5.pth \
--batch_size 8 --data_root /path/to/CelebAHQ

python PICNet_inference.py \
--data_root /path/to/CelebAHQ  \
--src_img_path images_masked_test \
--ref_img_path images \
--mask_path binary_map \
--identity_file_path CelebA-HQ-identity.txt \
--mask_detector_path saved_model/new_mask_detector.pth \
--pt_ckpt_path saved_model/PICNet_best_ref_HQ_better_att/G_checkpoint_epoch4.pth \
--img_scale 0.25 \
--use_att 1 \
--batch_size 4 \
--decoder_img_f 256 --decoder_z_nc 256
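As the flag names suggest, the PICNet command takes a reference image directory (--ref_img_path) and an identity file (--identity_file_path). Presumably this lets each masked source image be paired with an unmasked image of the same person. The sketch below shows one way such pairing could work. It is not the repo's code, and build_index and pick_reference are hypothetical names. It rests on two assumptions. First, each line of the identity file is "<image file name> <identity id>", the whitespace-separated format of CelebA's identity annotations. Second, masked test images keep the file name of their unmasked source.

```python
from __future__ import annotations

import random
from collections import defaultdict
from typing import Iterable


def build_index(lines: Iterable[str]) -> tuple[dict[str, str], dict[str, list[str]]]:
    """Build image -> identity and identity -> images maps from identity-file lines.

    ASSUMPTION: each line is "<image file name> <identity id>".
    """
    name_to_id = {}
    id_to_names = defaultdict(list)
    for line in lines:
        parts = line.split()
        if len(parts) != 2:
            continue  # skip blank lines and anything not matching the assumed format
        name, identity = parts
        name_to_id[name] = identity
        id_to_names[identity].append(name)
    return name_to_id, dict(id_to_names)


def pick_reference(src_name: str, name_to_id: dict[str, str],
                   id_to_names: dict[str, list[str]], rng: random.Random) -> str | None:
    """Pick a different image of the same identity as the reference, or None if there is none."""
    identity = name_to_id.get(src_name)
    if identity is None:
        return None
    others = [n for n in id_to_names[identity] if n != src_name]
    return rng.choice(others) if others else None
```

Because iterating a file yields its lines, `with open("/path/to/CelebAHQ/CelebA-HQ-identity.txt") as f: name_to_id, id_to_names = build_index(f)` builds the index directly from the file. Passing a seeded random.Random keeps the pairing reproducible across evaluation runs.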

Both scripts generate images from the test set, save the results in the test_results/[checkpoint_name] folder, and compute SSIM and MS-SSIM. To calculate FID, run:
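The inference scripts compute SSIM and MS-SSIM themselves. The snippet below is only an illustrative, pure-Python sketch of what single-scale SSIM measures, not the implementation the scripts use.

For each window, the metric compares local means, variances, and covariance using SSIM = ((2·μx·μy + C1)(2·σxy + C2)) / ((μx² + μy² + C1)(σx² + σy² + C2)), where C1 = (0.01·L)², C2 = (0.03·L)², and L is the pixel value range. It then averages the per-window scores.

As an assumed simplification, the sketch uses a uniform 7×7 window instead of the 11×11 Gaussian window (σ = 1.5) from the original SSIM definition, so its numbers will not match library implementations exactly. MS-SSIM, which repeats the contrast and structure comparison over several downsampled scales, is not covered.

```python
from __future__ import annotations


def ssim(img1: list[list[float]], img2: list[list[float]], win: int = 7,
         data_range: float = 255.0, k1: float = 0.01, k2: float = 0.03) -> float:
    """Mean SSIM of two equal-sized grayscale images given as 2-D lists.

    Simplified sketch: uniform win x win windows, stride 1, no padding.
    """
    h, w = len(img1), len(img1[0])
    if (h, w) != (len(img2), len(img2[0])):
        raise ValueError("images must have the same size")
    if h < win or w < win:
        raise ValueError("images must be at least win x win")
    c1 = (k1 * data_range) ** 2  # stabilizes the luminance term
    c2 = (k2 * data_range) ** 2  # stabilizes the contrast/structure term
    scores = []
    for i in range(h - win + 1):
        for j in range(w - win + 1):
            xs = [img1[r][c] for r in range(i, i + win) for c in range(j, j + win)]
            ys = [img2[r][c] for r in range(i, i + win) for c in range(j, j + win)]
            n = len(xs)
            mx, my = sum(xs) / n, sum(ys) / n
            vx = sum((a - mx) * (a - mx) for a in xs) / n
            vy = sum((b - my) * (b - my) for b in ys) / n
            cxy = sum((a - mx) * (b - my) for a, b in zip(xs, ys)) / n
            num = (2 * mx * my + c1) * (2 * cxy + c2)
            den = (mx * mx + my * my + c1) * (vx + vy + c2)
            scores.append(num / den)
    return sum(scores) / len(scores)
```

Identical images score 1.0. A black image compared with a white one scores close to 0, because only the small C1 constant keeps the luminance term from vanishing.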

python -m pytorch_fid test_results/[checkpoint_name] path/to/test/images
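For reference, pytorch_fid passes both image folders through an Inception-v3 network and fits a Gaussian to the pooled features of each set (2048-dimensional by default). It then reports the Fréchet distance between the two Gaussians, where lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Here \mu_r, \Sigma_r are the feature mean and covariance of the real test images, and \mu_g, \Sigma_g are those of the generated images in test_results/[checkpoint_name].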

Serving through Gradio

Run python gradio_serve.py with the appropriate flags to serve the model through Gradio.

Acknowledgements

This repo borrows heavily from other implementations. Namely:

License


This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.

This software is for educational and academic research purposes only.