ewrfcas / MST_inpainting

Learning a Sketch Tensor Space for Image Inpainting of Man-made Scenes (ICCV 2021)
MIT License


Chenjie Cao, Yanwei Fu


arXiv | Project Page

Overview

We learn an encoder-decoder model that encodes a Sketch Tensor (ST) space consisting of refined lines and edges. The model then recovers the masked images from the ST space.

News

This work has since been further improved in ZITS (CVPR 2022).

Preparation

  1. Prepare the environment.
  2. Download the pretrained masked wireframe detection model LSM-HAWP (retrained from HAWP, CVPR 2020).
  3. Download the weights for the setting you need into the 'check_points' folder: P2M (Man-made Places2), P2C (Comprehensive Places2), or shanghaitech (Shanghaitech, all man-made scenes).
  4. For training, we provide irregular and segmentation masks (download) with different masking rates. Define the mask file list before training (see flist_example.txt).
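A file list can be produced with a few lines of Python. This is a minimal sketch, assuming the list format is one absolute file path per line; check flist_example.txt to confirm. The helper name `write_flist` and the extension set are hypothetical, not part of the repo.

```python
from pathlib import Path

# Hypothetical helper. Assumed flist format: one absolute path per line
# (verify against flist_example.txt before using it for training).
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}


def write_flist(root: str, out_file: str, exts=IMAGE_EXTS) -> list:
    """Recursively collect image/mask files under `root` and write them, sorted, to `out_file`."""
    paths = sorted(
        str(p.resolve())
        for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )
    with open(out_file, "w", encoding="utf-8") as f:
        # Trailing newline only when there is at least one entry.
        f.write("\n".join(paths) + ("\n" if paths else ""))
    return paths
```

For example, `write_flist("masks/irregular", "irregular_masks.txt")` would list every mask image under that folder (paths here are placeholders). Sorting keeps the list deterministic across runs.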

Training

Because the training code was rewritten, it differs from the test code in a few ways:

  1. Training uses src/models.py, while testing uses src/model_inference.py.

  2. Images are scaled to [-1, 1] for training and to [0, 1] for testing.

  3. Masks are always concatenated to the inputs.

  1. Generate wireframes with LSM-HAWP.

    CUDA_VISIBLE_DEVICES=0 python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path>
  2. Set the file lists in training_configs/config_MST.yml (example: flist_example.txt).

  3. Train the inpainting model in two stages, stage 1 followed by stage 2.

    python train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
    python train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0

    For DDP training with multiple GPUs:

    python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
    python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3

Test on a single image

python test_single.py --gpu_id 0 \
                      --PATH ./check_points/MST_P2C \
                      --image_path <your image path> \
                      --mask_path <your mask path (0 means valid and 255 means masked)>
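If you do not already have a mask, one that follows the convention above (0 = valid, 255 = masked) can be built with NumPy. This is a minimal sketch; the helper `rectangle_mask` and its parameters are hypothetical, and a rectangular hole is just one simple choice of masked region.

```python
import numpy as np


def rectangle_mask(height: int, width: int, top: int, left: int, h: int, w: int) -> np.ndarray:
    """Return an HxW uint8 mask where 0 = valid and 255 = masked.

    The rectangle rows [top, top + h) and columns [left, left + w) is the region to inpaint.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[top:top + h, left:left + w] = 255
    return mask
```

To pass it to test_single.py, save it as an image, for example with Pillow: `Image.fromarray(rectangle_mask(256, 256, 64, 64, 128, 128)).save("mask.png")`, then use `--mask_path mask.png`. If an existing mask uses the opposite convention (255 = valid), `255 - mask` flips it.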

Object Removal Examples

[Video: object removal]

Comparisons

[Figure: ST Places2]