Compositional Image Decomposition with Diffusion Models (ICML 2024)

We propose Decomp Diffusion, an unsupervised approach that discovers compositional concepts from images, represented by diffusion models.

Project Page | Paper | Google Colab | Hugging Face


This is the official codebase for Compositional Image Decomposition with Diffusion Models.

Compositional Image Decomposition with Diffusion Models
Jocelin Su¹, Nan Liu², Yanbo Wang³, Joshua B. Tenenbaum¹, Yilun Du¹
Equal Contribution
¹MIT, ²UIUC, ³TU Delft


The demo notebook shows how to train a model and perform experiments on decomposition, reconstruction, and recombination of factors on CLEVR, as well as recombination in multi-modal and cross-dataset settings.

Setup

Run the following to create and activate a conda environment:

conda create -n decomp_diff python=3.8
conda activate decomp_diff

To install this package, clone this repository and then run:

pip install -e .

Training

We use a U-Net model architecture. To train a model, we specify the model parameters and training parameters as follows:

MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
TRAIN_FLAGS="--batch_size 16 --dataset clevr --data_dir ../"
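
For intuition, here is a conceptual sketch of what these flags configure (illustrative names and shapes only, not the repo's actual API): an encoder splits an image into num_components factor embeddings of size emb_dim, and a shared conditional U-Net denoises the image under each factor, with the per-factor predictions composed by summation.

```python
# Conceptual sketch only -- illustrative names, not the repo's actual API.
import torch
import torch.nn as nn

class FactorEncoder(nn.Module):
    """Map an image to K factor embeddings (K = num_components)."""
    def __init__(self, num_components=4, emb_dim=64, enc_channels=128):
        super().__init__()
        self.num_components, self.emb_dim = num_components, emb_dim
        self.net = nn.Sequential(
            nn.Conv2d(3, enc_channels, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(enc_channels, enc_channels, 3, stride=2, padding=1), nn.SiLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(enc_channels, num_components * emb_dim),
        )

    def forward(self, x):
        z = self.net(x)
        return z.view(-1, self.num_components, self.emb_dim)  # (B, K, emb_dim)

def composed_eps(unet, x_t, t, z):
    """Compose factors by summing the per-factor noise predictions."""
    return sum(unet(x_t, t, z[:, k]) for k in range(z.shape[1]))
```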

For distributed training, we run the following:

# Count the GPUs listed in CUDA_VISIBLE_DEVICES
DEVICE=$CUDA_VISIBLE_DEVICES
NUM_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l)

python -m torch.distributed.run --nproc_per_node=$NUM_DEVICES scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS

Otherwise, we run:

python scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS --use_dist False
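
As a rough illustration of what the training script optimizes (a sketch under the same assumptions as above, using a standard DDPM forward process; see scripts/image_train.py for the actual loss and schedule), the model noises an image and regresses the noise with the summed per-factor predictions:

```python
import torch
import torch.nn.functional as F

def training_step(unet, encoder, x0, alphas_cumprod):
    """One illustrative denoising training step."""
    B = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (B,), device=x0.device)
    a = alphas_cumprod[t].view(B, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * noise      # forward diffusion q(x_t | x_0)
    z = encoder(x0)                                   # (B, K, emb_dim) factor embeddings
    eps_hat = composed_eps(unet, x_t, t, z)           # summed noise prediction
    return F.mse_loss(eps_hat, noise)
```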

Inference

To generate images, we load a trained model and run a sampling loop with either DDPM or DDIM sampling. We provide pre-trained models for various datasets below; for example, a pre-trained CLEVR model is provided here.
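
For reference, a deterministic DDIM update (eta = 0) composing the per-factor predictions looks roughly like the sketch below; this is illustrative only, and the repo's actual sampling loop lives in scripts/gen_image_script.py.

```python
import torch

@torch.no_grad()
def ddim_step(unet, x_t, t, t_prev, z, alphas_cumprod):
    """One illustrative DDIM step from timestep t to t_prev (eta = 0)."""
    a_t = alphas_cumprod[t]
    a_prev = alphas_cumprod[t_prev] if t_prev >= 0 else x_t.new_tensor(1.0)
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    eps = composed_eps(unet, x_t, t_batch, z)              # summed noise prediction
    x0_pred = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()  # predicted clean image
    return a_prev.sqrt() * x0_pred + (1 - a_prev).sqrt() * eps
```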

To perform decomposition and reconstruction of an input image, run the following:

MODEL_CHECKPOINT="clevr_model.pt"
MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
python scripts/gen_image_script.py --dataset clevr --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path sample_images/clevr_im_10.png --save_dir gen_clevr_img/ --sample_method ddim

In addition, we can generate results for multiple images in a dataset:

python scripts/gen_image_script.py --gen_images 100 --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --save_dir gen_many_clevr_imgs/

Decomp Diffusion can also compose discovered factors. To combine factors across 2 images, run:

python scripts/gen_image_script.py --combine_method slice --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path $IM_PATH --im_path2 $IM_PATH2 --save_dir $SAVE_DIR 

See gen_image_script.py for additional options such as generating additive combinations or cross-dataset combinations.
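
To make the combination options concrete, here is a hypothetical sketch of the two strategies (helper names invented here, not the script's internals): slice keeps a subset of factors from each image, while an additive combination retains all factors from both images and lets the composed denoiser sum over the larger set.

```python
import torch

def combine_slice(z1, z2):
    """Take the first half of image 1's factors and the second half of image 2's."""
    K = z1.shape[1]
    return torch.cat([z1[:, : K // 2], z2[:, K // 2 :]], dim=1)

def combine_additive(z1, z2):
    """Keep every factor from both images; composed_eps then sums over all 2K."""
    return torch.cat([z1, z2], dim=1)
```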


Datasets

See our paper for details on training datasets. Note that Tetris images are 32x32 instead of 64x64.

| Dataset | Link |
| --- | --- |
| CLEVR | Link |
| CLEVR Toy | Link |
| Tetris | Link |
| CelebA-HQ 128x128 | Link |
| KITTI | Link |
| Virtual KITTI 2 | Link |
| Falcor3D | Link |
| Anime | Link |

Models

See our paper for details on model parameters for each dataset. We provide links to pre-trained models below, as well as their non-default parameter flags. We used --batch_size 32 during training.

| Model | Link | Model Flags |
| --- | --- | --- |
| CLEVR | Link | --emb_dim 64 --enc_channels 128 |
| CelebA-HQ | Link | --enc_channels 128 |
| Faces | Link | --enc_channels 128 |
| CLEVR Toy | Link | --emb_dim 64 --enc_channels 128 |
| Tetris | | --image_size 32 --num_components 3 --num_res_blocks 1 --enc_channels 64 |
| VKITTI | | --num_channels 64 --enc_channels 64 --emb_dim 256 |
| Combined KITTI | | --num_channels 64 --enc_channels 64 --emb_dim 256 |
| Falcor3D | | --num_channels 64 --emb_dim 32 --channel_mult 1,2 |