This is the official codebase for Unsupervised Compositional Image Decomposition with Diffusion Models.
[Compositional Image Decomposition with Diffusion Models]()
Jocelin Su<sup>1</sup>, Nan Liu<sup>2</sup>, Yanbo Wang<sup>3</sup>, Joshua B. Tenenbaum<sup>1</sup>, Yilun Du<sup>1</sup>

Equal Contribution

<sup>1</sup>MIT, <sup>2</sup>UIUC, <sup>3</sup>TU Delft
The demo notebook shows how to train a model and perform experiments on decomposition, reconstruction, and recombination of factors on CLEVR, as well as recombination in multi-modal and cross-dataset settings.
Run the following to create and activate a conda environment:
```bash
conda create -n decomp_diff python=3.8
conda activate decomp_diff
```
To install this package, clone this repository and then run:
```bash
pip install -e .
```
We use a U-Net model architecture. To train a model, we specify the model parameters and the training parameters as follows:

```bash
MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
TRAIN_FLAGS="--batch_size 16 --dataset clevr --data_dir ../"
```
For distributed training, we run the following:

```bash
DEVICE=$CUDA_VISIBLE_DEVICES
NUM_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l)
python -m torch.distributed.run --nproc_per_node=$NUM_DEVICES scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS
```
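The `NUM_DEVICES` line counts the comma-separated GPU IDs in `CUDA_VISIBLE_DEVICES`, and `torch.distributed.run` then launches one process per GPU. The Python sketch below mirrors that count and reads the rank variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`) that `torch.distributed.run` exports to each worker. The helper names are hypothetical, and how `scripts/image_train.py` itself consumes these variables is not shown here.

```python
import os


def count_visible_devices(env=None):
    """Count GPU IDs listed in CUDA_VISIBLE_DEVICES, like the NUM_DEVICES pipeline.

    Unlike `echo ... | wc -l`, which reports 1 for an unset or empty variable,
    this returns 0 in that case.
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES", "").strip()
    if not value:
        return 0
    # Skip empty entries, e.g. from a trailing comma in "0,1,".
    return len([d for d in value.split(",") if d.strip()])


def dist_info(env=None):
    """Read the rank variables that torch.distributed.run sets for each worker."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


if __name__ == "__main__":
    print(count_visible_devices({"CUDA_VISIBLE_DEVICES": "0,1,3"}))  # 3
    print(dist_info({"RANK": "1", "LOCAL_RANK": "1", "WORLD_SIZE": "3"}))
```

The defaults in `dist_info` describe a single, non-distributed process, which is what a script sees when it is started without the launcher.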
Otherwise, for non-distributed (single-process) training, we run:

```bash
python scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS --use_dist False
```
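`--use_dist False` passes the string `False` on the command line. With argparse, `type=bool` would turn any non-empty string (including `"False"`) into `True`, so scripts that take boolean flags this way usually parse them with a small converter. A minimal sketch, assuming an argparse-based script: `str2bool` and `build_parser` are hypothetical, the flag names come from the commands above, and the defaults are illustrative only.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings; argparse's type=bool would treat 'False' as True."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")


def build_parser():
    # Hypothetical parser: flag names are from this README, defaults are illustrative.
    parser = argparse.ArgumentParser()
    parser.add_argument("--emb_dim", type=int, default=64)
    parser.add_argument("--enc_channels", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--dataset", type=str, default="clevr")
    parser.add_argument("--data_dir", type=str, default="../")
    parser.add_argument("--use_dist", type=str2bool, default=True)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--batch_size", "16", "--use_dist", "False"])
    print(args.batch_size, args.use_dist)  # 16 False
```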
To generate images, we run a sampling loop with a trained model, using either DDPM or DDIM sampling as selected by `--sample_method` (for example, `--sample_method ddim`). We provide pre-trained models for various datasets below; for example, a pre-trained CLEVR model is listed in the pre-trained models table.
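DDPM sampling draws fresh Gaussian noise at every reverse step, while DDIM (with η = 0) takes a deterministic step computed from the model's noise prediction. The sketch below shows one standard DDIM update in NumPy, of the kind selected by `--sample_method ddim`; `ddim_step` is a hypothetical name, and this is an illustration of the update rule, not the repository's sampler.

```python
import numpy as np


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0).

    alpha_bar_t and alpha_bar_prev are the cumulative products of (1 - beta)
    at the current and the previous (less noisy) timestep.
    """
    # Estimate the clean image implied by the predicted noise.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # Move that estimate to the previous timestep, reusing the same noise direction.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps_pred


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.standard_normal((3, 8, 8))
    eps = rng.standard_normal((3, 8, 8))
    a_t, a_prev = 0.5, 0.9
    x_t = np.sqrt(a_t) * x0 + np.sqrt(1 - a_t) * eps
    # With a perfect noise prediction, the step lands exactly on the noised
    # sample at the previous timestep that shares the same noise.
    x_prev = ddim_step(x_t, eps, a_t, a_prev)
    print(np.allclose(x_prev, np.sqrt(a_prev) * x0 + np.sqrt(1 - a_prev) * eps))  # True
```

Because no noise is injected, repeated DDIM sampling from the same starting noise gives the same image, which is convenient when comparing decompositions.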
To perform decomposition and reconstruction of an input image, run the following:
```bash
MODEL_CHECKPOINT="clevr_model.pt"
MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
python scripts/gen_image_script.py --dataset clevr --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path sample_images/clevr_im_10.png --save_dir gen_clevr_img/ --sample_method ddim
```
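Decomposition encodes an input image into several latent components, each of which conditions the same denoiser. Reconstruction denoises with a combination of the per-component noise predictions, while denoising with a single component renders the factor it captures. The sketch below assumes a hypothetical `denoise_fn(x_t, t, z)` and combines predictions by averaging; whether the repository averages or sums them is an assumption here.

```python
import numpy as np


def composed_eps(denoise_fn, x_t, t, latents):
    """Combine per-component noise predictions eps(x_t, t, z_k) by averaging over k.

    Averaging rather than summing is an assumption of this sketch.
    """
    preds = np.stack([denoise_fn(x_t, t, z) for z in latents], axis=0)
    return preds.mean(axis=0)


def component_eps(denoise_fn, x_t, t, latents, k):
    """Noise prediction conditioned on component k alone, used to visualize that factor."""
    return denoise_fn(x_t, t, latents[k])


if __name__ == "__main__":
    # Hypothetical stand-in denoiser: predicts a constant equal to the latent's mean.
    def toy_denoise(x_t, t, z):
        return np.full_like(x_t, z.mean())

    x_t = np.zeros((3, 4, 4))
    latents = np.array([[1.0, 1.0], [3.0, 3.0]])  # K = 2 components, D = 2
    print(composed_eps(toy_denoise, x_t, 0, latents)[0, 0, 0])  # 2.0
```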
In addition, we can generate results for multiple images in a dataset (here, 100 images), where `$DATASET` is the dataset name, e.g. `clevr`:

```bash
python scripts/gen_image_script.py --gen_images 100 --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --save_dir gen_many_clevr_imgs/
```
Decomp Diffusion can also recombine discovered factors. To combine factors across two images, given by `$IM_PATH` and `$IM_PATH2`, run:

```bash
python scripts/gen_image_script.py --combine_method slice --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path $IM_PATH --im_path2 $IM_PATH2 --save_dir $SAVE_DIR
```
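One plausible reading of `--combine_method slice` is that the recombined image takes some latent components from the first image and the rest from the second, and is then denoised with the composed prediction as in reconstruction. The sketch below illustrates that idea on latent arrays of shape `(K, D)`; `slice_combine` and `num_from_a` are hypothetical and may not match the script's actual slicing rule.

```python
import numpy as np


def slice_combine(latents_a, latents_b, num_from_a):
    """Hypothetical 'slice' recombination of two (K, D) latent arrays.

    Keeps components 0..num_from_a-1 from image A and the remaining
    components from image B.
    """
    latents_a = np.asarray(latents_a)
    latents_b = np.asarray(latents_b)
    if latents_a.shape != latents_b.shape:
        raise ValueError("both images must be encoded into the same (K, D) shape")
    if not 0 <= num_from_a <= latents_a.shape[0]:
        raise ValueError("num_from_a must lie in [0, K]")
    return np.concatenate([latents_a[:num_from_a], latents_b[num_from_a:]], axis=0)


if __name__ == "__main__":
    a = np.zeros((4, 3))  # K = 4 components encoded from image A
    b = np.ones((4, 3))   # K = 4 components encoded from image B
    mixed = slice_combine(a, b, num_from_a=2)
    print(mixed[:, 0])  # [0. 0. 1. 1.]
```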
See `scripts/gen_image_script.py` for additional options, such as generating additive combinations or cross-dataset combinations.
See our paper for details on training datasets. Note that Tetris images are 32x32 instead of 64x64.
Dataset | Link |
---|---|
CLEVR | Link |
CLEVR Toy | Link |
Tetris | Link |
CelebA-HQ 128x128 | Link |
KITTI | Link |
Virtual KITTI 2 | Link |
Falcor3D | Link |
Anime | Link |
See our paper for details on the model parameters for each dataset. We provide links to pre-trained models below, along with their non-default parameter flags. We used `--batch_size 32` during training.
Model | Link | Model Flags |
---|---|---|
CLEVR | Link | `--emb_dim 64 --enc_channels 128` |
CelebA-HQ | Link | `--enc_channels 128` |
Faces | Link | `--enc_channels 128` |
CLEVR Toy | Link | `--emb_dim 64 --enc_channels 128` |
Tetris | Link | `--image_size 32 --num_components 3 --num_res_blocks 1 --enc_channels 64` |
VKITTI | Link | `--num_channels 64 --enc_channels 64 --emb_dim 256` |
Combined KITTI | Link | `--num_channels 64 --enc_channels 64 --emb_dim 256` |
Falcor3D | Link | `--num_channels 64 --emb_dim 32 --channel_mult 1,2` |
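Each checkpoint has to be loaded with the architecture flags it was trained with, which is why the sampling examples above pass the same `$MODEL_FLAGS` used for training. A minimal sketch for keeping the table's non-default flags in one place and building a sampling command from them: the dictionary keys are row labels from the table (not `--dataset` values), and `sample_command` is a hypothetical helper.

```python
import shlex

# Non-default model flags, copied from the pre-trained model table.
MODEL_FLAGS = {
    "CLEVR": "--emb_dim 64 --enc_channels 128",
    "CelebA-HQ": "--enc_channels 128",
    "Faces": "--enc_channels 128",
    "CLEVR Toy": "--emb_dim 64 --enc_channels 128",
    "Tetris": "--image_size 32 --num_components 3 --num_res_blocks 1 --enc_channels 64",
    "VKITTI": "--num_channels 64 --enc_channels 64 --emb_dim 256",
    "Combined KITTI": "--num_channels 64 --enc_channels 64 --emb_dim 256",
    "Falcor3D": "--num_channels 64 --emb_dim 32 --channel_mult 1,2",
}


def sample_command(model, ckpt_path, dataset, im_path, save_dir):
    """Build the argument list for a gen_image_script.py run with the model's flags."""
    args = ["python", "scripts/gen_image_script.py",
            "--dataset", dataset, "--ckpt_path", ckpt_path]
    args += shlex.split(MODEL_FLAGS[model])  # split the flag string into tokens
    args += ["--im_path", im_path, "--save_dir", save_dir]
    return args


if __name__ == "__main__":
    cmd = sample_command("CLEVR", "clevr_model.pt", "clevr",
                         "sample_images/clevr_im_10.png", "gen_clevr_img/")
    print(shlex.join(cmd))
```

The returned list can be passed directly to `subprocess.run` without going through a shell.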