drboog / Shifted_Diffusion

Code for Shifted Diffusion for Text-to-image Generation (CVPR 2023)
Creative Commons Zero v1.0 Universal

Shifted Diffusion for Text-to-image Generation

[Figure: examples]


Shifted Diffusion is a new diffusion model designed to better generate image embeddings from text.

[Figure: framework]

("Decoder" can be either a diffusion-based or a GAN-based model; it can also be conditioned on both the image embedding and the text.)
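The framework can be read as a two-stage pipeline: a prior (Shifted Diffusion) turns a text embedding into an image embedding, and a decoder turns that image embedding into an image. The sketch below only illustrates this data flow; `encode_text`, `prior`, `decoder` and `text_to_image` are hypothetical placeholders, not functions from this repository.

```python
from typing import Callable, List

Vector = List[float]


def text_to_image(
    text: str,
    encode_text: Callable[[str], Vector],
    prior: Callable[[Vector], Vector],
    decoder: Callable[..., object],
    also_condition_on_text: bool = False,
) -> object:
    # Stage 1: text embedding -> image embedding (the role played by Shifted Diffusion).
    image_emb = prior(encode_text(text))
    # Stage 2: image embedding -> image (a diffusion- or GAN-based decoder).
    if also_condition_on_text:
        # Optional variant: the decoder also receives the raw text.
        return decoder(image_emb, text=text)
    return decoder(image_emb)


if __name__ == "__main__":
    # Toy stand-ins (hypothetical) so the data flow can be run end to end.
    toy_encode = lambda s: [float(len(s))]
    toy_prior = lambda v: [2.0 * x for x in v]
    toy_decoder = lambda emb, text=None: ("image", tuple(emb), text)
    print(text_to_image("a red bus", toy_encode, toy_prior, toy_decoder))
```

Keeping the prior and the decoder as separate callables mirrors the point above: either stage can be swapped (e.g. a diffusion decoder for a GAN) without touching the other.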

With Shifted Diffusion, you can

Below we provide examples of how to use Shifted Diffusion.

Don't forget to create a new conda environment in advance.

Get started

Install the dependencies:

pip install -r ./requirements.txt
pip install git+https://github.com/openai/CLIP.git
# Install the local ./diffusers package in editable mode
cd ./diffusers
pip install -e .
cd ..
# Download the CLIP checkpoints (RN101, ViT-B/32, ViT-B/16)
wget "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
wget "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
wget "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
# Configure Accelerate (interactive) for your hardware
accelerate config
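The CLIP checkpoint URLs follow the layout used by the openai/CLIP package, where the second-to-last path component is the SHA-256 digest of the file (the CLIP loader verifies its own downloads this way). A minimal stdlib sketch for checking the `wget` downloads, assuming the files were saved to the current directory under their original names; `verify_download` and the other helpers are hypothetical names:

```python
import hashlib
from pathlib import Path

CLIP_URLS = [
    "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
]


def expected_sha256(url: str) -> str:
    # The directory just before the file name is the hex SHA-256 of the file.
    return url.split("/")[-2]


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    # Hash in chunks so large checkpoints are not read into memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(url: str, directory: str = ".") -> bool:
    path = Path(directory) / url.split("/")[-1]
    return path.is_file() and sha256_of_file(path) == expected_sha256(url)


if __name__ == "__main__":
    for url in CLIP_URLS:
        status = "OK" if verify_download(url) else "missing or corrupted"
        print(url.split("/")[-1], status)
```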

To train a Shifted Diffusion model, run the following (choose hyper-parameters according to your hardware):

accelerate launch --mixed_precision="fp16" train.py

We provide our pre-trained Shifted Diffusion models here.

Shifted Diffusion + Stable Diffusion

We provide a simple example that combines our pre-trained Shifted Diffusion with Stable Diffusion 2.

Specifically, we add a projection layer that maps the input image embedding to 4 word embeddings. Feel free to try more complex architectures.
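A minimal NumPy sketch of such a projection layer, for illustration only. It assumes a single linear layer whose output is reshaped into 4 token embeddings; the 512-dimensional input (e.g. a CLIP ViT-B/32 image embedding) and the 1024-dimensional output (the text-encoder width used by Stable Diffusion 2) are assumptions, and `ImageEmbeddingProjector` is a hypothetical name. We have not checked how the repository feeds these tokens to the UNet; presumably they act as the cross-attention context in place of text-encoder hidden states.

```python
import numpy as np


class ImageEmbeddingProjector:
    """Hypothetical projection: one image embedding -> num_tokens word embeddings.

    Computes y = x W + b and reshapes y to (batch, num_tokens, token_dim).
    Dimensions are assumptions, not values taken from the repository.
    """

    def __init__(self, embed_dim: int = 512, token_dim: int = 1024,
                 num_tokens: int = 4, seed: int = 0) -> None:
        rng = np.random.default_rng(seed)
        self.num_tokens = num_tokens
        self.token_dim = token_dim
        # Small random init; in practice these weights are learned during fine-tuning.
        self.weight = rng.normal(0.0, 0.02, size=(embed_dim, num_tokens * token_dim))
        self.bias = np.zeros(num_tokens * token_dim)

    def __call__(self, image_emb: np.ndarray) -> np.ndarray:
        # image_emb: (batch, embed_dim) -> (batch, num_tokens, token_dim)
        flat = image_emb @ self.weight + self.bias
        return flat.reshape(image_emb.shape[0], self.num_tokens, self.token_dim)


if __name__ == "__main__":
    projector = ImageEmbeddingProjector()
    tokens = projector(np.random.default_rng(1).normal(size=(2, 512)))
    print(tokens.shape)  # (2, 4, 1024)
```

In PyTorch the same layer would be a `torch.nn.Linear(512, 4 * 1024)` followed by a reshape to `(batch, 4, 1024)`.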

With the example below, one can first fine-tune a Stable Diffusion model on an image-only dataset (language-free setting), then

Fine-tune a Stable Diffusion model

Prepare an image-only dataset (e.g., MS-COCO):

wget http://images.cocodataset.org/zips/train2014.zip
unzip train2014.zip
python process_img.py --src=./train2014 --size=512 --dest=./train2014
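We have not checked exactly what `process_img.py` does with `--size=512`; a common preprocessing step is to center-crop each image to a square and resize it to 512×512. A sketch of that step under this assumption, using Pillow; `preprocess_image` is a hypothetical helper, not part of the repository:

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def preprocess_image(src: str, dest: str, size: int = 512) -> None:
    """Hypothetical helper: center-crop an image to a square, then resize to size x size."""
    from PIL import Image  # lazy import: center_crop_box works without Pillow

    with Image.open(src) as img:
        rgb = img.convert("RGB")  # loads pixel data before the file is closed
    square = rgb.crop(center_crop_box(*rgb.size))
    square.resize((size, size), Image.BICUBIC).save(dest)
```

Note that `finetune.py` is also called with `--center_crop --resolution=512`, so some of this cropping may happen again at training time.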

Run

accelerate launch --mixed_precision="fp16" finetune.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-base" \
  --train_data_dir=./train2014/ \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=8 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --max_train_steps=30000 \
  --checkpointing_steps=5000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="./finetuned_coco"

(We did not tune the hyper-parameters; they follow the examples here.)

Here are some slightly fine-tuned Stable Diffusion 2 models; we used a total batch size of 8 × 8 × 1 = 64.

(The models are "slightly fine-tuned": we fine-tuned them for only 10k to 30k steps, for illustration purposes. More fine-tuning steps and better-tuned hyper-parameters should lead to better results.)
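The total batch size of 64 is the product of the per-device batch size (`--train_batch_size=8`), the number of processes, and the gradient-accumulation steps (`--gradient_accumulation_steps=1`); reading the factors this way implies 8 GPUs, which is our assumption. The arithmetic below also estimates how many passes over the data a 30k-step run corresponds to, assuming all 82,783 MS-COCO train2014 images are used; the function names are hypothetical.

```python
def effective_batch_size(per_device_batch: int, num_processes: int, grad_accum_steps: int) -> int:
    # Samples contributing to one optimizer update across all devices.
    return per_device_batch * num_processes * grad_accum_steps


def approx_epochs(max_train_steps: int, total_batch: int, dataset_size: int) -> float:
    # Number of passes over the dataset implied by a fixed step budget.
    return max_train_steps * total_batch / dataset_size


if __name__ == "__main__":
    total = effective_batch_size(8, 8, 1)  # 8 GPUs is an assumption
    print(total)  # 64
    print(round(approx_epochs(30_000, total, 82_783), 1))  # 23.2
```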

Test fine-tuned Stable Diffusion model

Generate image with CLIP image embedding

Run

python test.py

Examples of input/generated images on different datasets:

[Figures: input and generated images for the pelican, train, and face examples]

Generate image with text + Shifted Diffusion

Run

python sft_test.py

Below we provide a comparison.

[Figure: yellow-and-blue-train comparison]

A ground-truth image-text pair from the MS-COCO dataset is shown.

Although Stable Diffusion 2 can perform zero-shot generation, the generated images may not meet our requirements, e.g., in terms of style.

With our language-free fine-tuning and the pre-trained Shifted Diffusion model, we can generate images that meet these requirements.

This approach can easily be applied to other domains or datasets, since no image-text pairs are needed for fine-tuning.

Below is a comparison between Shifted Diffusion and a baseline diffusion model, both used with the fine-tuned Stable Diffusion 2 model. We evaluate FID and CLIP similarity (averaged over CLIP ViT-B/16, ViT-B/32, and RN101) between the generated images and the input text or the ground-truth target images.

[Figures: FID and CLIP similarity (image-image); FID and CLIP similarity (image-text)]
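The CLIP similarity in this comparison is a cosine similarity between CLIP embeddings, averaged over three CLIP backbones. A stdlib sketch of that averaging, assuming the embeddings were already computed (e.g. with `model.encode_image` / `model.encode_text` from the openai/CLIP package) and that scores are averaged per model first and then across models; the exact averaging and any scaling behind the reported numbers may differ, and the function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Pair = Tuple[Sequence[float], Sequence[float]]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def average_clip_similarity(pairs_by_model: Dict[str, List[Pair]]) -> float:
    """Mean cosine similarity per CLIP model, then mean over models.

    pairs_by_model maps a model name (e.g. "ViT-B/16") to a list of
    (generated_embedding, reference_embedding) pairs computed with that model.
    The reference is a text embedding (image-text) or a ground-truth image embedding (image-image).
    """
    per_model = []
    for pairs in pairs_by_model.values():
        sims = [cosine_similarity(gen, ref) for gen, ref in pairs]
        per_model.append(sum(sims) / len(sims))
    return sum(per_model) / len(per_model)
```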

Shifted Diffusion + Lafite

The decoder can also be a GAN-based model, e.g., Lafite.

As in the example above, one needs to construct an image-only dataset and then train a mapping from image embeddings to images.

After the GAN is trained, the pre-trained Shifted Diffusion model can be used directly for text-to-image generation at inference time.

Citation

@article{zhou2022shifted,
  title={Shifted Diffusion for Text-to-image Generation},
  author={Zhou, Yufan and Liu, Bingchen and Zhu, Yizhe and Yang, Xiao and Chen, Changyou and Xu, Jinhui},
  journal={arXiv preprint arXiv:2211.15388},
  year={2022}
}