We learn an encoder-decoder model that encodes a Sketch Tensor (ST) space consisting of refined lines and edges. The model then recovers the masked images from the ST space.
This work has since been further improved in ZITS (CVPR 2022).
Because the training code was rewritten, it differs from the test code in a few ways:
Training uses src/models.py while testing uses src/model_inference.py.
Images are scaled to [-1, 1] for training and to [0, 1] for testing.
Masks are always concatenated to the inputs.
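The value-range difference between training and testing is easy to get wrong when moving tensors between the two code paths. The sketch below works on flattened channels stored as plain Python lists. It shows the two scalings, how to convert a training-range tensor to the testing range, and how a binary mask can be appended as an extra channel. The helper names are hypothetical, and this does not reproduce the repository's actual data loader, including how masked pixels are filled.

```python
from typing import List

Channel = List[float]  # one image channel, flattened to a list of values


def uint8_to_train(pixels: List[int]) -> Channel:
    """Map uint8 values in [0, 255] to the training range [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


def uint8_to_test(pixels: List[int]) -> Channel:
    """Map uint8 values in [0, 255] to the testing range [0, 1]."""
    return [p / 255.0 for p in pixels]


def train_to_test(values: Channel) -> Channel:
    """Convert values from the training range [-1, 1] to the testing range [0, 1]."""
    return [(v + 1.0) / 2.0 for v in values]


def concat_mask(image: List[Channel], mask: Channel) -> List[Channel]:
    """Append a binary mask (1 = masked, 0 = valid) as an extra input channel.

    Hypothetical helper: only the concatenation is shown; how the repository
    fills the masked pixels before concatenation is not modeled here.
    """
    return [list(ch) for ch in image] + [list(mask)]
```

For an RGB image, `concat_mask` returns four channels, so the network's first layer expects one more input channel than the image alone provides.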
Generate wireframes with lsm-hawp:
CUDA_VISIBLE_DEVICES=0 python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path>
Set the file lists in training_configs/config_MST.yml (see flist_example.txt for an example).
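If your dataset is a directory tree of images, a file list can be generated rather than written by hand. The sketch below assumes that the list is a plain text file with one image path per line, as in EdgeConnect-style `.flist` files. Check this against flist_example.txt before relying on it. The extension set and function names are hypothetical.

```python
import os
from typing import Iterable, List

# Assumed set of image extensions; adjust to your dataset.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp")


def collect_images(root: str, exts: Iterable[str] = IMAGE_EXTS) -> List[str]:
    """Recursively collect absolute image paths under `root`, sorted for reproducibility."""
    exts = tuple(e.lower() for e in exts)
    paths = []
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            # Case-insensitive extension match, so "b.JPG" is included.
            if name.lower().endswith(exts):
                paths.append(os.path.abspath(os.path.join(dirpath, name)))
    return sorted(paths)


def write_flist(paths: List[str], out_file: str) -> None:
    """Write one path per line (assumed format of flist_example.txt)."""
    with open(out_file, "w", encoding="utf-8") as f:
        f.writelines(p + "\n" for p in paths)
```

Sorting the paths keeps the list stable across runs, which makes train/validation splits reproducible.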
Train the inpainting model in two stages, running stage 1 before stage 2:
python train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
python train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
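To run both stages unattended, a small driver can launch them in order and stop if stage 1 fails. This is a minimal sketch that reuses the script names and flags from the commands above. It assumes that both stages must share the same `--path` so that stage 2 can pick up the stage-1 results; the helper names are hypothetical.

```python
import subprocess
import sys
from typing import List

CONFIG = "training_configs/config_MST.yml"


def stage_commands(model_name: str, gpu: str = "0", config: str = CONFIG) -> List[List[str]]:
    """Build the stage-1 and stage-2 training commands, in execution order.

    Both stages use the same --path (assumed to be how stage 2 finds the
    stage-1 outputs).
    """
    return [
        [sys.executable, script, "--path", model_name, "--config", config, "--gpu", gpu]
        for script in ("train_MST_stage1.py", "train_MST_stage2.py")
    ]


def run_two_stages(model_name: str, gpu: str = "0") -> None:
    """Run stage 1, then stage 2; check=True raises before stage 2 if stage 1 fails."""
    for cmd in stage_commands(model_name, gpu):
        subprocess.run(cmd, check=True)
```

Calling `run_two_stages("<model_name>")` from the repository root is equivalent to typing the two commands above one after the other.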
For DDP training with multiple GPUs:
python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
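In the DDP commands, `--nproc_per_node` (one process per GPU) must match the number of IDs passed to `--gpu`. The sketch below derives one from the other so that they cannot drift apart when you change the GPU count. It builds the same command shape as above, and the function name is hypothetical.

```python
import sys
from typing import List, Sequence


def ddp_command(script: str, model_name: str, gpus: Sequence[int],
                config: str = "training_configs/config_MST.yml") -> List[str]:
    """Build a torch.distributed.launch command for the given GPU IDs.

    --nproc_per_node is computed from the GPU list, so it always equals
    the number of GPUs passed to --gpu.
    """
    if not gpus:
        raise ValueError("need at least one GPU")
    gpu_arg = ",".join(str(g) for g in gpus)
    return [
        sys.executable, "-m", "torch.distributed.launch",
        f"--nproc_per_node={len(gpus)}",
        script, "--path", model_name, "--config", config, "--gpu", gpu_arg,
    ]
```

For example, `ddp_command("train_MST_stage1.py", "<model_name>", [0, 1])` yields `--nproc_per_node=2` together with `--gpu 0,1`.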
Test on a single image:
python test_single.py --gpu_id 0 \
--PATH ./check_points/MST_P2C \
--image_path <your image path> \
--mask_path <your mask path (0 means valid and 255 means masked)>
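The test script expects masks in which 0 marks valid pixels and 255 marks masked pixels. Masks drawn the other way round, or with anti-aliased edges, need converting first. The sketch below operates on a flat list of grayscale values. The function name and the default threshold of 127 are assumptions, not part of the repository.

```python
from typing import List


def to_mst_mask(values: List[int], masked_is_white: bool = True,
                threshold: int = 127) -> List[int]:
    """Convert grayscale mask values to the 0 (valid) / 255 (masked) convention.

    Set masked_is_white=False for masks where dark pixels mark the hole;
    those are inverted. Intermediate (anti-aliased) values are binarized
    at `threshold`.
    """
    out = []
    for v in values:
        hole = v > threshold if masked_is_white else v <= threshold
        out.append(255 if hole else 0)
    return out
```

With an image library such as Pillow, the same mapping can be applied per pixel before saving the mask that is passed to `--mask_path`.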