Dual Diffusion Implicit Bridges for Image-to-Image Translation
Xuan Su,
Jiaming Song,
Chenlin Meng,
Stefano Ermon
ICLR '23 |
GitHub | arXiv
| Colab
| Project Page
Common image-to-image translation methods rely on joint training over data from both source and target domains. The training process requires concurrent access to both datasets, which hinders data separation and privacy protection; and existing models cannot be easily adapted to translation of new domain pairs. We present Dual Diffusion Implicit Bridges (DDIBs), an image translation method based on diffusion models, that circumvents training on domain pairs. Image translation with DDIBs relies on two diffusion models trained independently on each domain, and is a two-step process: DDIBs first obtain latent encodings for source images with the source diffusion model, and then decode such encodings using the target model to construct target images. Both steps are defined via ordinary differential equations (ODEs), thus the process is cycle consistent only up to discretization errors of the ODE solvers. Theoretically, we interpret DDIBs as concatenation of source to latent, and latent to target Schrödinger Bridges, a form of entropy-regularized optimal transport, to explain the efficacy of the method. Experimentally, we apply DDIBs on synthetic and high-resolution image datasets, to demonstrate their utility in a wide variety of translation tasks and their inherent optimal transport properties.
Installation follows the same procedures as in the above repositories.
We first install the current repository, and then install other libraries like numpy, matplotlib
etc. My successful
installation contains the following version numbers, with Python 3.9:
pip install -e .
pip install numpy==1.24.0 matplotlib==3.6.2 scikit-image==0.19.3 scikit-learn==1.2.0 gdown==4.6.0
conda install -c conda-forge mpi4py openmpi
We release pretrained checkpoints for the 2D synthetic models in the paper.
Downloading via script: In your repository, run python download.py --exp synthetic
to download the pretrained
synthetic models. The
script will create a directory models/synthetic
and automatically download the checkpoints to the directory.
Downloading manually: As an alternative, you can also download the checkpoint manually. Here is the download link for the model checkpoints: Synthetic Models
Indexes. We use indexes 0-5 to refer to the 6 synthetic types, in: [Moons, Checkerboards, Concentric Rings, Concentric Squares, Parallel Rings, Parallel Squares].
How are the datasets generated? The key file to look at is: guided_diffusion/synthetic_datasets.py
. We implement
the data generation and sampling processes for various 2D modalities.
After running the download script, we can run the cycle consistency, synthetic translation and sampling experiments below.
python scripts/synthetic_train.py --num_res_blocks 3 --diffusion_steps 4000 --noise_schedule linear --lr 1e-4 --batch_size 20000 --task 0
Task is an integer in {0, 1, 2, 3, 4, 5} corresponding to one of the synthetic types.
Training each model probably takes only a few (3-4) hours on a GPU.
The models and logs are saved to the directory at OPENAI_LOGDIR
. If you want to save the model files to your desired
folder, modify the variable via export OPENAI_LOGDIR=...
python scripts/synthetic_cycle.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 30000 --source 0 --target 1
The above command runs the cycle-consistent translation experiment in the paper: between datasets Moons (0) and
Checkerboards (1). The generated experiment plots are saved under the new directory experiments/images
.
python scripts/synthetic_translation.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 30000 --source 0 --target 3
The above command performs translation between the two synthetic domains and saves the resulting plots
to experiments/images
.
python scripts/synthetic_sample.py --num_res_blocks 3 --diffusion_steps 4000 --batch_size 20000 --num_samples 80000 --task 1
ModuleNotFoundError: No module named 'guided_diffusion'
, here is one possible
solution: https://stackoverflow.com/a/23210066/3284573.Download the model weights. Similarly, run python download.py --exp imagenet
to download the pretrained,
class-conditional ImageNet models from guided-diffusion. The script will
create a directory models/imagenet
and put the classifier and diffusion model weights there.
Copy the validation dataset. We use the ImageNet validation set for domain translation. Two steps:
val_dir
, as we'll need it for the next command.We are now ready to translate the images.
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python scripts/imagenet_translation.py $MODEL_FLAGS --classifier_scale 1.0 --source 260,261,282,283 --target 261,262,283,284 --val_dir path_to_val_set
The above command copies the validation images for the source classes to ./experiments/imagenet
. Translated images are
placed in the same folder, with the target class appended in the filename.
We can update source
and target
to translate between other ImageNet classes. The corresponding val images are copied
automatically.
We translate the domain pairs in the specified order. For example, in the above command, we translate from class 260 to 261, 283 to 284, etc.
We can experiment with classifier_scale
to guide the denoising process towards the target class with different
strengths.
We can prepend the Python command with mpiexec -n N
to run it over multiple GPUs. For details, refer
to guided-diffusion.
@inproceedings{
su2022dual,
title={Dual Diffusion Implicit Bridges for Image-to-Image Translation},
author={Su, Xuan and Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
booktitle={International Conference on Learning Representations},
year={2023},
}
This implementation is based on / inspired by: OpenAI: openai/guided-diffusion and openai/improved-diffusion.