Official Implementation of Denoising Diffusion Bridge Models.
To install all packages in this codebase along with their dependencies, run
pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install packaging ninja
conda install -c conda-forge mpi4py openmpi
pip install -e .
We provide pretrained checkpoints via a Hugging Face repo here. It includes models trained on two image-to-image datasets, Edges2Handbags and DIODE, using Variance-Preserving (VP) schedules.
For Edges2Handbags, please follow instructions from here. For DIODE, please download appropriate datasets from here.
We provide bash files train_ddbm.sh and sample_ddbm.sh for model training and sampling.
Simply set the variables DATASET_NAME and SCHEDULE_TYPE:
DATASET_NAME specifies which dataset to use. We only support e2h for Edges2Handbags and diode for DIODE. For each dataset, make sure to set the respective DATA_DIR variable in args.sh to your dataset path.
SCHEDULE_TYPE denotes the noise schedule type. Only ve and vp are recommended. ve_simple and vp_simple are their naive baselines.
To train, run
bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE
# to resume, set CKPT to your checkpoint, or it will automatically resume from your last checkpoint based on your experiment name.
bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE $CKPT
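Since a typo in either argument may only surface after the job has started, it can help to validate them first. The following is a minimal sketch, not part of this codebase: the helper name build_train_command is hypothetical, while the allowed values and the argument order are taken from the description above.

```python
# Hypothetical helper (not part of this repo): validate the two required
# arguments and assemble the training command as an argv list.
DATASETS = {"e2h", "diode"}
SCHEDULES = {"ve", "vp", "ve_simple", "vp_simple"}


def build_train_command(dataset_name, schedule_type, ckpt=None):
    if dataset_name not in DATASETS:
        raise ValueError(f"DATASET_NAME must be one of {sorted(DATASETS)}")
    if schedule_type not in SCHEDULES:
        raise ValueError(f"SCHEDULE_TYPE must be one of {sorted(SCHEDULES)}")
    cmd = ["bash", "train_ddbm.sh", dataset_name, schedule_type]
    # CKPT is optional; without it the script resumes from the last
    # checkpoint of the experiment, as described above.
    if ckpt is not None:
        cmd.append(ckpt)
    return cmd


if __name__ == "__main__":
    print(" ".join(build_train_command("e2h", "vp")))
```

The returned list can be passed to subprocess.run(cmd, check=True) from the repository root.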
For inference, additional variables need to be set:
MODEL_PATH is the checkpoint to be evaluated.
CHURN_STEP_RATIO is the fraction of each step that is used for the stochastic Euler step. The recommended default is 0.33. Lower values generally degrade performance. See the paper for details and for guidance on choosing a value.
GUIDANCE is the w parameter specified in the paper. The recommended default is 1 for VP schedules; anything less than 1 produces significantly worse results. For VE schedules, however, this value (ranging from 0 to 1) does not affect generation much. See the paper for guidance on choosing a value.
SPLIT denotes which split to use for testing. Only train and test are supported.
To sample, run
bash sample_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE $MODEL_PATH $CHURN_STEP_RATIO $GUIDANCE $SPLIT
This script aggregates all samples into a .npz file in your experiment folder, ready for quantitative evaluation.
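The constraints on the sampling arguments can be checked up front in the same spirit. This is a sketch under stated assumptions rather than part of the codebase: build_sample_command is a hypothetical name, and the [0, 1] bound on CHURN_STEP_RATIO is an assumption based on it being described as a ratio of a step.

```python
import warnings

# Hypothetical helper (not part of this repo).
DATASETS = {"e2h", "diode"}
SCHEDULES = {"ve", "vp", "ve_simple", "vp_simple"}
SPLITS = {"train", "test"}


def build_sample_command(dataset_name, schedule_type, model_path,
                         churn_step_ratio=0.33, guidance=1, split="test"):
    if dataset_name not in DATASETS:
        raise ValueError(f"DATASET_NAME must be one of {sorted(DATASETS)}")
    if schedule_type not in SCHEDULES:
        raise ValueError(f"SCHEDULE_TYPE must be one of {sorted(SCHEDULES)}")
    if split not in SPLITS:
        raise ValueError(f"SPLIT must be one of {sorted(SPLITS)}")
    # Assumption: a ratio of a step lies in [0, 1].
    if not 0.0 <= churn_step_ratio <= 1.0:
        raise ValueError("CHURN_STEP_RATIO is expected to lie in [0, 1]")
    # The text above reports much worse results for VP schedules when w < 1.
    if schedule_type.startswith("vp") and guidance < 1:
        warnings.warn("GUIDANCE < 1 is reported to hurt VP schedules")
    return ["bash", "sample_ddbm.sh", dataset_name, schedule_type,
            model_path, str(churn_step_ratio), str(guidance), split]
```

The defaults mirror the recommended values given above (0.33 for the churn step ratio and 1 for the guidance weight).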
Samples can be evaluated with evaluations/evaluator.py. We also provide the reference statistics in our Hugging Face repo.
To evaluate, set REF_PATH to the path of your reference stats and SAMPLE_PATH to the path of your generated .npz file. You can additionally specify the metric to use via --metric. We only support fid and lpips.
python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric $YOUR_METRIC
We noticed that on some machines mpiexec errors out with
--------------------------------------------------------------------------
MPI_INIT has failed because at least one MPI process is unreachable
from another. This *usually* means that an underlying communication
plugin -- such as a BTL or an MTL -- has either not loaded or not
allowed itself to be used. Your MPI job will now abort.
You may wish to try to narrow down the problem;
* Check the output of ompi_info to see which BTL/MTL plugins are
available.
* Run your application with MPI_THREAD_SINGLE.
* Set the MCA parameter btl_base_verbose to 100 (or mtl_base_verbose,
if using MTL-based communications) to see exactly which
communication plugins were considered and/or discarded.
--------------------------------------------------------------------------
In this case, you can try adding --mca btl vader,self to the mpiexec command, before the python invocation.
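If the launch command is assembled programmatically, the flag placement described above can be sketched as follows. The helper name add_mca_flags and the script name in the example are hypothetical; only the flags themselves come from the text above.

```python
import os


def add_mca_flags(cmd, flags=("--mca", "btl", "vader,self")):
    """Return a copy of cmd with flags inserted just before the python executable."""
    cmd = list(cmd)
    for i, token in enumerate(cmd):
        # Match python, python3, /usr/bin/python3.10, and similar.
        if os.path.basename(token).startswith("python"):
            return cmd[:i] + list(flags) + cmd[i:]
    raise ValueError("no python executable found in command")


if __name__ == "__main__":
    # your_script.py is a placeholder, not a file in this repo.
    print(" ".join(add_mca_flags(["mpiexec", "-n", "4", "python", "your_script.py"])))
```

Inserting the flags before the python token keeps them in mpiexec's own option list rather than passing them to the Python script.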
During evaluation, if you see unusually high LPIPS or MSE scores, the likely cause is a mismatch in order between your generated samples and the reference stats. This can happen when the multiprocess gathering of results returns samples in the wrong order. Please make sure the order of your generated samples is correct, or regenerate the reference stats yourself.
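One way to check for such a mismatch is to test whether each generated sample is closest to the reference image at the same index. The sketch below is only a diagnostic under assumptions: it presumes you have paired reference images available as flat sequences of pixel values (the layout of the reference stats file is not described here), and all function names are hypothetical.

```python
def mse(a, b):
    """Mean squared error between two equal-length sequences of numbers."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def nearest_reference(generated, reference):
    """For each generated sample, the index of the lowest-MSE reference item."""
    return [
        min(range(len(reference)), key=lambda j: mse(g, reference[j]))
        for g in generated
    ]


def fraction_in_order(generated, reference):
    """Fraction of samples whose nearest reference sits at the same index."""
    nearest = nearest_reference(generated, reference)
    return sum(i == j for i, j in enumerate(nearest)) / len(nearest)
```

A fraction well below 1 on a small subset suggests the two sets are not aligned. The comparison is quadratic in the number of samples, so it is best run on a few dozen items rather than the full set.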
If you find this method and/or code useful, please consider citing
@article{zhou2023denoising,
title={Denoising diffusion bridge models},
author={Zhou, Linqi and Lou, Aaron and Khanna, Samar and Ermon, Stefano},
journal={arXiv preprint arXiv:2309.16948},
year={2023}
}