This repo contains the PyTorch implementation of "Attention-Based Spatial Guidance for Image-to-Image Translation" (WACV 2021). This implementation is based on the official CycleGAN code. Our model and contributions are in `./models/attn_cycle_gan_model.py` and `./models/attn_cycle_gan_v2_model.py`, respectively.
Clone this repo:

```bash
git clone https://github.com/voidstrike/ASGIT
```
Install the requirements of the official CycleGAN repo (visdom and dominate). You can install all of them with one command:

```bash
pip install -r requirements.txt
```
For conda users, the script `./scripts/conda_dep.sh` installs PyTorch and the other libraries (not tested).
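If the script does not work in your environment, the setup it performs is roughly the following (a hedged sketch; the environment name `asgit` and the version pins are assumptions, not taken from the script):

```bash
# Hypothetical manual setup; adjust the Python/PyTorch versions to your CUDA toolkit.
conda create -n asgit python=3.8 -y
conda activate asgit
conda install pytorch torchvision -c pytorch -y
pip install visdom dominate
```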
The Cityscapes dataset can be downloaded from the Cityscapes website; you must register an account to access and download it.
ImageNet-based datasets such as apple2orange and horse2zebra can be downloaded using `./scripts/download_cyclegan_model.sh`. The Day2Night dataset can be downloaded from here. Please note that those street images are cropped from the BDD100K dataset.
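Since this implementation is built on the official CycleGAN code, each dataset presumably follows CycleGAN's unaligned folder layout (a sketch; the domain examples are illustrative):

```
DATASETPATH/
├── trainA/   # training images from domain A (e.g. horse)
├── trainB/   # training images from domain B (e.g. zebra)
├── testA/    # test images from domain A
└── testB/    # test images from domain B
```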
Train a model (the following commands correspond to PHA+RHP, PHA+Alpha, TAM+RHP, and TAM+Alpha, respectively):

```bash
python3 train_attn.py --netG resnet_9blocks --netD posthoc_attn --model attn_cycle_gan --concat rmult --dataroot DATASETPATH --name EXP_NAME
python3 train_attn.py --netG resnet_9blocks --netD posthoc_attn_v2 --model attn_cycle_gan_v2 --concat alpha --dataroot DATASETPATH --name EXP_NAME
python3 train_attn.py --netG resnet_9blocks --netD trainable_attn --model attn_cycle_gan --concat rmult --dataroot DATASETPATH --name EXP_NAME
python3 train_attn.py --netG resnet_9blocks --netD trainable_attn_v2 --model attn_cycle_gan_v2 --concat alpha --dataroot DATASETPATH --name EXP_NAME
```
There are more hyperparameter options; please refer to the source code for details. Please modify DATASETPATH and EXP_NAME accordingly.
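For example, to train the PHA+RHP variant on horse2zebra (the dataset path and experiment name below are illustrative values, not required ones):

```bash
python3 train_attn.py --netG resnet_9blocks --netD posthoc_attn --model attn_cycle_gan \
    --concat rmult --dataroot ./datasets/horse2zebra --name horse2zebra_pha_rhp
```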
To view training results and loss plots, run `python -m visdom.server` and go to http://localhost:8097. To see more intermediate results, check out `./checkpoints/EXP_NAME/web/index.html`.
Test a model (please make sure `--netG`, `--netD`, `--model`, and `--concat` are consistent with the training command):

```bash
python3 test.py --netG resnet_9blocks --netD posthoc_attn --model attn_cycle_gan --concat rmult --dataroot DATASETPATH
python3 test.py --netG resnet_9blocks --netD posthoc_attn_v2 --model attn_cycle_gan_v2 --concat alpha --dataroot DATASETPATH
python3 test.py --netG resnet_9blocks --netD trainable_attn --model attn_cycle_gan --concat rmult --dataroot DATASETPATH
python3 test.py --netG resnet_9blocks --netD trainable_attn_v2 --model attn_cycle_gan_v2 --concat alpha --dataroot DATASETPATH
```
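For example, to test the PHA+RHP model trained above (assuming `test.py` follows the official CycleGAN code, it should also accept `--name` and write translated images under `./results/EXP_NAME/`; this is an assumption based on the upstream repo, not verified here):

```bash
python3 test.py --netG resnet_9blocks --netD posthoc_attn --model attn_cycle_gan \
    --concat rmult --dataroot ./datasets/horse2zebra --name horse2zebra_pha_rhp
# Results are expected under ./results/horse2zebra_pha_rhp/ (CycleGAN convention; assumed).
```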
We provide some translation results of our model.
If you use this code or dataset for your research, please consider citing our paper:
```bibtex
@InProceedings{Lin_2021_WACV,
    author    = {Lin, Yu and Wang, Yigong and Li, Yifan and Gao, Yang and Wang, Zhuoyi and Khan, Latifur},
    title     = {Attention-Based Spatial Guidance for Image-to-Image Translation},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2021},
    pages     = {816-825}
}
```
Please note:
- For `--netD`, a network specified without the `_v2` suffix (e.g. `posthoc_attn`) means the attention mask will be normalized into [0, 1]. Model names with and without the `_v2` suffix are therefore not interchangeable, so keep `--netD` and `--model` consistent.
- Set `--mask_size IMG_SIZE` to the shape of your input, and modify `attn_cycle_gan_model.py` accordingly.
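For reference, normalizing an attention mask into [0, 1] can be done with a simple min-max rescaling. The snippet below is an illustrative sketch only; the actual scheme lives in `./models/attn_cycle_gan_model.py`, and the function name `normalize_mask` is hypothetical:

```python
import torch

def normalize_mask(mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Min-max rescale each mask in the batch into [0, 1].
    # Illustrative only; see ./models/attn_cycle_gan_model.py for the real scheme.
    m_min = mask.amin(dim=(-2, -1), keepdim=True)
    m_max = mask.amax(dim=(-2, -1), keepdim=True)
    return (mask - m_min) / (m_max - m_min + eps)
```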