Kunhee Kim, Sanghun Park, Eunyeong Jeon, Taehun Kim, Daijin Kim
POSTECH
Our model discovers various style prototypes from the dataset in a self-supervised manner. The style prototype consists of a combination of various attributes including (left) time, weather, season, and texture; and (right) age, gender, and accessories.
Paper: https://arxiv.org/abs/2203.15375
Abstract: Current image-to-image translations do not control the output domain beyond the classes used during training, nor do they interpolate between different domains well, leading to implausible results. This limitation largely arises because labels do not consider the semantic distance. To mitigate such problems, we propose a style-aware discriminator that acts as a critic as well as a style encoder to provide conditions. The style-aware discriminator learns a controllable style space using prototype-based self-supervised learning and simultaneously guides the generator. Experiments on multiple datasets verify that the proposed model outperforms current state-of-the-art image-to-image translation methods. In contrast with current methods, the proposed approach supports various applications, including style interpolation, content transplantation, and local image translation.
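The style prototypes mentioned above come from prototype-based self-supervised learning in the style space. As a rough illustration only (this is not the repository's actual module; all names and dimensions below are made up), assigning a style code to learnable prototypes can be sketched as follows:

```python
# Illustrative sketch only -- not the repository's implementation.
# A style code is compared to learnable prototype vectors by cosine
# similarity, giving a soft assignment over "style prototypes".
import torch
import torch.nn.functional as F


class PrototypeAssigner(torch.nn.Module):
    def __init__(self, latent_dim=64, num_prototypes=32, temperature=0.1):
        super().__init__()
        # Learnable prototypes spanning the style space (sizes are arbitrary here).
        self.prototypes = torch.nn.Parameter(torch.randn(num_prototypes, latent_dim))
        self.temperature = temperature

    def forward(self, style_code):
        z = F.normalize(style_code, dim=-1)          # (B, D) normalized style codes
        p = F.normalize(self.prototypes, dim=-1)     # (K, D) normalized prototypes
        logits = z @ p.t() / self.temperature        # (B, K) cosine similarities
        return logits.softmax(dim=-1)                # soft prototype assignment


assigner = PrototypeAssigner()
codes = torch.randn(4, 64)        # stand-in for style-encoder outputs
print(assigner(codes).shape)      # torch.Size([4, 32])
```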
Clone this repository:
git clone https://github.com/kunheek/style-aware-discriminator.git
cd style-aware-discriminator
Then, install dependencies using anaconda or pip:
conda env create -f environment.yml
# or
pip install -r requirements.txt
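Optionally, you can check that PyTorch and a GPU are visible before moving on (a quick sanity check, not a required step):

```python
# Quick environment check: confirms the PyTorch install and GPU visibility.
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```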
We provide the following pre-trained networks.
Checkpoint | Dataset | Resolution | Method | #images
---|---|---|---|---
afhq-adain | AFHQ | $256^2$ | AdaIN | 1.6 M |
afhq-stylegan2 | AFHQ | $256^2$ | StyleGAN2 | 5 M |
afhqv2 | AFHQ v2 | $512^2$ | StyleGAN2 | 5 M |
celebahq-adain | CelebA-HQ | $256^2$ | AdaIN | 1.6 M |
celebahq-stylegan2 | CelebA-HQ | $256^2$ | StyleGAN2 | 5 M |
church | LSUN church | $256^2$ | StyleGAN2 | 25 M |
ffhq | FFHQ | $256^2$ | StyleGAN2 | 25 M |
flower | Oxford 102 | $256^2$ | AdaIN | 1.6 M |
We uploaded the checkpoints to HuggingFace. You can download them using the following command:
# download all checkpoints.
python download.py checkpoints
# download a specific checkpoint.
python download.py afhq-adain
See the table above or download.py for available checkpoints.
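If you want to peek inside a downloaded checkpoint before using it, a minimal sketch is below (the path matches the examples later in this README; the top-level keys depend on how the checkpoint was saved, so treat them as unknown):

```python
# Inspect a downloaded checkpoint on CPU without building the model.
import torch

state = torch.load("checkpoints/afhq-stylegan2-5M.pt", map_location="cpu")
if isinstance(state, dict):
    for key, value in state.items():
        print(key, type(value).__name__)
```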
(Optional) Computing inception stats takes a long time. We provide pre-calculated stats for the AFHQ 256 and CelebA-HQ 256 datasets (link). You can download and register them using the following commands:
python download.py stats
# python -m tools.register_stats PATH/TO/STATS
python -m tools.register_stats assets/stats
To evaluate our model, run `python -m metrics METRICS --checkpoint CKPT --train-dataset TRAINDIR --eval-dataset EVALDIR`. By default, all metrics will be saved in `runs/{run-dir}/metrics.txt`. Available metrics include:
- `fid` (add `--eval-kid true` to also compute KID)
- `mean_fid` (add `--eval-mkid true` to also compute mean KID)
- `reconstruction`

See `metrics/{task}_evaluator.py` for task-specific options. You can pass multiple tasks at the same time. Here are some examples:
python -m metrics fid reconstruction --seed 123 --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --train-dataset ./datasets/afhq/train --eval-dataset ./datasets/afhq/val
python -m metrics mean_fid --seed 777 --checkpoint ./checkpoints/celebahq-stylegan2-5M.pt --train-dataset ./datasets/celeba_hq/train --eval-dataset ./datasets/celeba_hq/val
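To sweep several checkpoints, one option is a small driver script like the sketch below (the checkpoint file names are assumptions; substitute the ones you actually downloaded):

```python
# Hypothetical convenience script: run the metrics module for several checkpoints.
import subprocess

checkpoints = [
    "./checkpoints/afhq-adain-1.6M.pt",      # assumed file name
    "./checkpoints/afhq-stylegan2-5M.pt",
]
for ckpt in checkpoints:
    subprocess.run(
        [
            "python", "-m", "metrics", "fid", "reconstruction",
            "--checkpoint", ckpt,
            "--train-dataset", "./datasets/afhq/train",
            "--eval-dataset", "./datasets/afhq/val",
        ],
        check=True,
    )
```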
You can synthesize images similarly to the quantitative evaluations (replace `metrics` with `synthesis`). By default, all images will be saved in the `runs/{run-dir}/{task}` folder.
# python -m synthesis [TASKS] --checkpoint PATH/TO/CKPT --folder PATH/TO/FOLDERS
python -m synthesis swap --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style
python -m synthesis interpolation --checkpoint ./checkpoints/afhq-stylegan2-5M.pt --folder ./testphotos/afhq/content ./testphotos/afhq/style
Some tasks require multiple folders (e.g., content and style) or extra arguments. Available synthesis tasks include `swap` and `interpolation`, as shown in the examples above.
We provide additional tools for visualizing and exploring the learned style space:
python -m tools.plot_tsne --checkpoint checkpoints/afhq-stylegan2-5M.pt --target-dataset datasets/afhq/val --seed 7 --title AFHQ --labels cat dog wild
python -m tools.plot_tsne --checkpoint checkpoints/celebahq-stylegan2-5M.pt --target-dataset datasets/celeba_hq/val --seed 7 --title CelebA-HQ --legends female male
python -m tools.similarity_search --checkpoint CKPT --query QUERY_IMAGE --target-dataset TESTDIR
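Conceptually, the similarity search ranks dataset images by cosine similarity to the query in the learned style space. A self-contained sketch of that ranking step (using dummy embeddings; tools.similarity_search extracts the real embeddings from the checkpoint) looks like this:

```python
# Rank gallery embeddings by cosine similarity to a query embedding.
import torch
import torch.nn.functional as F


def top_k_similar(query, gallery, k=5):
    q = F.normalize(query, dim=-1)        # (D,)
    g = F.normalize(gallery, dim=-1)      # (N, D)
    scores = g @ q                        # cosine similarity per gallery item
    return torch.topk(scores, k)          # values and indices of the k best matches


values, indices = top_k_similar(torch.randn(64), torch.randn(100, 64))
print(indices.tolist())
```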
By default, all images in the folder will be used for training or evaluation (supported image formats can be found here). For example, if you pass `--train-dataset=./datasets/afhq/train`, all images in the `./datasets/afhq/train` folder will be used for training.
For LSUN datasets, `lsun` must be included in the folder path.
datasets
└─ lsun
├─ church_outdoor_train_lmdb
└─ church_outdoor_val_lmdb
To measure `mean_fid`, a subdirectory for each class must exist (fewer than 5 classes). If you want to reproduce the experiments in the paper, we recommend using the following structure:
datasets
├─ afhq
│ ├─ train
│ │ ├─ cat
│ │ ├─ dog
│ │ └─ wild
│ └─ val (or test)
│ └─ (cat/dog/wild)
└─ celeba_hq
├─ train
│ ├─ female
│ └─ male
└─ val
└─ (female/male)
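A small helper (not part of the repository) can confirm that the per-class subfolders expected by `mean_fid` are in place:

```python
# List class subdirectories of a dataset split, e.g. before running mean_fid.
from pathlib import Path


def list_classes(root):
    classes = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    print(f"{root}: {len(classes)} classes -> {classes}")
    return classes


list_classes("datasets/afhq/val")       # expected: ['cat', 'dog', 'wild']
list_classes("datasets/celeba_hq/val")  # expected: ['female', 'male']
```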
Notice: We recommend training networks on a single GPU with enough memory (e.g., an A100) to obtain the best results, since we observed performance degradation with the current implementation when using multiple GPUs (DDP). For example, a model trained on one A100 GPU (40 GB) is slightly better than a model trained on two TITAN Xp GPUs (12 GB each). We used a single NVIDIA A100 GPU for the AFHQ and CelebA-HQ experiments and four NVIDIA RTX 3090 GPUs for the AFHQ v2, LSUN churches, and FFHQ experiments. Note that we disabled TF32 for all experiments.
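If you want to match the TF32 setting in your own scripts (the provided training code may already handle this), the standard PyTorch switches are:

```python
# Disable TF32 math on Ampere+ GPUs to match the reported training setup.
import torch

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```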
We provide training scripts here. Use the following commands to train networks with custom arguments:
# Single GPU training.
python train.py --mod-type adain --total-nimg 1.6M --batch-size 16 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/afhq/train --eval-dataset datasets/afhq/val --out-dir runs --extra-desc some descriptions
# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --total-nimg 25M --batch-size 64 --load-size 320 --crop-size 256 --image-size 256 --train-dataset datasets/ffhq/images1024x1024 --eval-dataset datasets/ffhq/images1024x1024 --nb-proto 128 --latent-dim 512 --latent-ratio 0.5 --jitter true --cutout true --out-dir runs --extra-desc some descriptions
Training options, code, checkpoints, and snapshots will be saved in the `{out-dir}/{run-id}-{dataset}-{resolution}-{extra-desc}` directory. Please see train.py, model.py, and augmentation.py for the available arguments.
To resume training, run `python train.py --resume PATH/TO/RUNDIR`. For example:
# Single GPU training.
python train.py --resume runs/000-afhq-256x256-some-descriptions
# Multi-GPU training.
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --resume runs/001-ffhq-some-descriptions
If you find this repository useful for your research, please cite our paper:
@InProceedings{kim2022style,
title={A Style-Aware Discriminator for Controllable Image Translation},
author={Kim, Kunhee and Park, Sanghun and Jeon, Eunyeong and Kim, Taehun and Kim, Daijin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2022},
pages={18239--18248}
}
Many of our implementations are adapted from previous works, including SwAV, DINO, StarGAN v2, Swapping Autoencoder, clean-fid, and stylegan2-pytorch.
All materials except custom CUDA kernels in this repository are made available under the MIT License.
The custom CUDA kernels (fused_bias_act_kernel.cu and upfirdn2d_kernel.cu) are under the Nvidia Source Code License, and are for non-commercial use only.