alexzhou907 / DDBM


Denoising Diffusion Bridge Models (ICLR 2024)

Official implementation of Denoising Diffusion Bridge Models.

Dependencies

To install all packages in this codebase along with their dependencies, run

pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
pip install packaging ninja
conda install -c conda-forge mpi4py openmpi
pip install -e .
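After installation, a quick import check can confirm that the packages above are visible to the active Python interpreter. The sketch below is a hypothetical helper, not part of the repository. The module list is an assumption derived from the install commands: ninja is omitted, and the package installed by `pip install -e .` is not checked because its import name is not stated here.

```python
import importlib.util

# Assumption: import names of the packages installed by the commands above.
REQUIRED = ["torch", "torchvision", "torchaudio", "packaging", "mpi4py"]


def missing_packages(names):
    """Return the module names that the current interpreter cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment(names=REQUIRED):
    """Print missing packages, or the torch version and CUDA status if all are present."""
    missing = missing_packages(names)
    if missing:
        print("Missing packages:", ", ".join(missing))
        return False
    import torch  # imported only after confirming it is installed

    print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
    return True


if __name__ == "__main__":
    check_environment()
```

If CUDA is reported as unavailable, the CPU-only torch wheel was most likely picked up instead of the cu121 one.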

Pre-trained models

We provide pretrained checkpoints in our Hugging Face repo here. It includes models trained on two image-to-image datasets using Variance-Preserving (VP) schedules.

Datasets

For Edges2Handbags, please follow the instructions from here. For DIODE, please download the appropriate datasets from here.

Model training and sampling

We provide the bash scripts train_ddbm.sh and sample_ddbm.sh for model training and sampling.

Simply set the variables DATASET_NAME and SCHEDULE_TYPE.

To train, run

bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE 

# To resume, set CKPT to your checkpoint; if CKPT is not set, training automatically resumes from the last checkpoint of your experiment (identified by its experiment name).

bash train_ddbm.sh $DATASET_NAME $SCHEDULE_TYPE $CKPT
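If you prefer to pass CKPT explicitly, you first need the most recent checkpoint file. The sketch below picks the file with the largest trailing step number in a directory. The directory layout and the `<name>_<step>.pt` naming pattern are assumptions for illustration, and this is not how train_ddbm.sh resolves checkpoints internally; check the filenames your run actually writes before relying on it.

```python
import re
from pathlib import Path

# Assumption: checkpoint filenames end in a step number followed by ".pt",
# e.g. "model_050000.pt". Adjust the pattern to match your run's files.
STEP_PATTERN = re.compile(r"(\d+)\.pt$")


def latest_checkpoint(ckpt_dir, prefix=""):
    """Return the path of the checkpoint with the highest step, or None if none match."""
    best_step, best_path = -1, None
    for path in Path(ckpt_dir).glob(f"{prefix}*.pt"):
        match = STEP_PATTERN.search(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    import sys

    # Usage: python this_file.py <checkpoint_dir> [filename_prefix]
    ckpt = latest_checkpoint(sys.argv[1], prefix=sys.argv[2] if len(sys.argv) > 2 else "")
    print(ckpt if ckpt is not None else "no checkpoint found")
```

The printed path can then be passed as $CKPT to train_ddbm.sh. The optional prefix lets you choose between, for example, raw and EMA weights if your run saves both.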

For inference, additional variables need to be set:

Evaluations

One can evaluate samples with evaluations/evaluator.py. We also provide the reference statistics in our Hugging Face repo.

To evaluate, set REF_PATH to the path of your reference stats and SAMPLE_PATH to the path of your generated .npz file. You can additionally specify which metrics to compute via --metric; only fid and lpips are supported.

python evaluations/evaluator.py $REF_PATH $SAMPLE_PATH --metric $YOUR_METRIC
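For intuition about the fid metric, the sketch below computes the Fréchet distance between two Gaussians given their means and covariances, which is the quantity FID reports when those statistics are taken from Inception features. It is an independent NumPy illustration, not the code in evaluations/evaluator.py, and it does not read the reference-stats file because that file's layout is not described here. The matrix square-root term is evaluated through the symmetric form Tr((Σ₁^{1/2} Σ₂ Σ₁^{1/2})^{1/2}), which equals Tr((Σ₁Σ₂)^{1/2}) for positive semidefinite covariances.

```python
import numpy as np


def _psd_sqrt(mat):
    """Symmetric square root of a positive semidefinite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Return ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr((S1 S2)^{1/2})."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1 = np.atleast_2d(sigma1).astype(float)
    sigma2 = np.atleast_2d(sigma2).astype(float)
    s1_half = _psd_sqrt(sigma1)
    middle = s1_half @ sigma2 @ s1_half  # symmetric PSD, same eigenvalues as S1 S2
    eigvals = np.clip(np.linalg.eigvalsh(middle), 0.0, None)
    tr_covmean = np.sum(np.sqrt(eigvals))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Given feature arrays of shape (N, D), the inputs would be `feats.mean(axis=0)` and `np.cov(feats, rowvar=False)`. Because FID compares distributions rather than pairs, it is insensitive to sample order, unlike LPIPS and MSE.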

Troubleshooting

We noticed that on some machines mpiexec fails with the following error:

--------------------------------------------------------------------------
MPI_INIT has failed because at least one MPI process is unreachable
from another.  This *usually* means that an underlying communication
plugin -- such as a BTL or an MTL -- has either not loaded or not
allowed itself to be used.  Your MPI job will now abort.

You may wish to try to narrow down the problem;  

 * Check the output of ompi_info to see which BTL/MTL plugins are
   available.
 * Run your application with MPI_THREAD_SINGLE.  
 * Set the MCA parameter btl_base_verbose to 100 (or mtl_base_verbose,
   if using MTL-based communications) to see exactly which
   communication plugins were considered and/or discarded.
--------------------------------------------------------------------------

In this case, try adding --mca btl vader,self to the mpiexec command, before the python invocation.
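The MCA option is read by mpiexec itself, so it has to appear before `python` rather than among the script's own arguments. The sketch below only assembles such a command line for illustration; the process count and arguments are placeholders, and `your_script.py` is a hypothetical name, not a file in this repository.

```python
import shlex


def mpiexec_command(script, script_args=(), nprocs=1, btl="vader,self"):
    """Build an mpiexec command list with --mca btl placed before the python call.

    Passing btl=None leaves Open MPI's default transport selection unchanged.
    """
    cmd = ["mpiexec", "-n", str(nprocs)]
    if btl:
        cmd += ["--mca", "btl", btl]  # MCA options are consumed by mpiexec, not by Python
    cmd += ["python", script, *script_args]
    return cmd


if __name__ == "__main__":
    # Placeholder script and arguments for illustration only.
    print(shlex.join(mpiexec_command("your_script.py", ["--help"], nprocs=2)))
```

The same list can be passed to `subprocess.run(cmd, check=True)`; when editing the bash scripts directly, insert `--mca btl vader,self` at the equivalent position in their mpiexec line.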

During evaluation, if you see unusually high LPIPS or MSE scores, the likely cause is a mismatch in ordering between your generated samples and the reference stats: gathering results from multiple processes may return them in the wrong order. Please make sure your generated samples are in the correct order, or regenerate the reference stats yourself.
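The sketch below illustrates why ordering matters for paired metrics such as MSE and LPIPS, and one way to restore it. It assumes, hypothetically, that each process keeps the dataset index of every sample it generates and that data is sharded round-robin across processes; the repository's sampling code may shard and gather differently. Sorting the gathered samples by those indices realigns them with the reference images.

```python
import numpy as np


def restore_order(samples, indices):
    """Sort gathered samples by their dataset indices so row i matches reference i."""
    indices = np.asarray(indices)
    order = np.argsort(indices, kind="stable")
    return samples[order], indices[order]


def paired_mse(generated, reference):
    """Mean squared error between generated[i] and reference[i], averaged over all pairs."""
    diff = generated.astype(np.float64) - reference.astype(np.float64)
    return float(np.mean(diff ** 2))


def demo(num_images=8, world_size=2, seed=0):
    rng = np.random.default_rng(seed)
    reference = rng.standard_normal((num_images, 3, 4, 4))
    generated = reference + 0.01 * rng.standard_normal(reference.shape)  # near-perfect samples

    # Hypothetical round-robin sharding: rank r handles indices r, r + world_size, ...
    shard_indices = [np.arange(r, num_images, world_size) for r in range(world_size)]
    gathered_idx = np.concatenate(shard_indices)  # rank-major order: 0,2,4,6,1,3,5,7
    gathered = generated[gathered_idx]

    before = paired_mse(gathered, reference)  # most pairs are misaligned
    fixed, fixed_idx = restore_order(gathered, gathered_idx)
    after = paired_mse(fixed, reference)  # pairs realigned with the reference
    return before, after, fixed_idx


if __name__ == "__main__":
    before, after, _ = demo()
    print(f"MSE before reordering: {before:.4f}, after: {after:.6f}")
```

In practice the same reordering would be applied before writing the .npz file passed as SAMPLE_PATH. FID is unaffected either way, since it compares feature statistics rather than individual pairs.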

Citation

If you find this method and/or code useful, please consider citing

@article{zhou2023denoising,
  title={Denoising diffusion bridge models},
  author={Zhou, Linqi and Lou, Aaron and Khanna, Samar and Ermon, Stefano},
  journal={arXiv preprint arXiv:2309.16948},
  year={2023}
}