salesforce / UniControl

Unified Controllable Visual Generation Model
https://canqin001.github.io/UniControl-Page/
Apache License 2.0

UniControl arXiv webpage HuggingFace space

This repository is for the paper:

UniControl: A Unified Diffusion Model for Controllable Visual Generation In the Wild
Can Qin (1,2), Shu Zhang (1), Ning Yu (1), Yihao Feng (1), Xinyi Yang (1), Yingbo Zhou (1), Huan Wang (1), Juan Carlos Niebles (1), Caiming Xiong (1), Silvio Savarese (1), Stefano Ermon (3), Yun Fu (2), Ran Xu (1)
(1) Salesforce AI, (2) Northeastern University, (3) Stanford University
Work done when Can Qin was an intern at Salesforce AI Research.


Introduction

We introduce UniControl, a new generative foundation model that consolidates a wide array of controllable condition-to-image (C2I) tasks within a singular framework, while still allowing for arbitrary language prompts. UniControl enables pixel-level-precise image generation, where visual conditions primarily influence the generated structures and language prompts guide the style and context. To equip UniControl with the capacity to handle diverse visual conditions, we augment pretrained text-to-image diffusion models and introduce a task-aware HyperNet to modulate the diffusion models, enabling the adaptation to different C2I tasks simultaneously. Experimental results show that UniControl often surpasses the performance of single-task-controlled methods of comparable model sizes. This control versatility positions UniControl as a significant advancement in the realm of controllable visual generation.


Updates

MultiGen-20M Datasets

MultiGen-20M contains more than 20M image-prompt-condition triplets, with a total size above 2 TB. It covers all 12 tasks (Canny, HED, Sketch, Depth, Normal, Skeleton, Bbox, Seg, Outpainting, Inpainting, Deblurring, Colorization), all of which are fully released.

Instructions

Environment Preparation

Set up the environment first (this takes a few minutes).

conda env create -f environment.yaml
conda activate unicontrol

Checkpoint Preparation

The pre-trained UniControl checkpoint should be saved at ./ckpts/unicontrol.ckpt:

cd ckpts
wget https://storage.googleapis.com/sfr-unicontrol-data-research/unicontrol.ckpt 

You can also use the latest trained model, available in both ckpt and safetensors formats:

wget https://storage.googleapis.com/sfr-unicontrol-data-research/unicontrol_v1.1.ckpt
wget https://storage.googleapis.com/sfr-unicontrol-data-research/unicontrol_v1.1.st

If you want to train from scratch, follow ControlNet to prepare the checkpoint initialization. ControlNet provides a simple script for this. If your SD checkpoint is ./ckpts/v1-5-pruned.ckpt and you want the script to save the processed model (SD+ControlNet) at ./ckpts/control_sd15_ini.ckpt, run:

python tool_add_control.py ./ckpts/v1-5-pruned.ckpt ./ckpts/control_sd15_ini.ckpt

Data Preparation

Download the training dataset (MultiGen-20M) to ./multigen20m:

cd multigen20m
gsutil -m cp -r gs://sfr-unicontrol-data-research/dataset ./

Then unzip all the files.
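The unzip step can be scripted. The following is a minimal sketch, assuming the downloaded archives are standard .zip files and that each should be extracted into the folder that contains it; the helper name `extract_all_archives` is hypothetical and is not part of this repository.

```python
import zipfile
from pathlib import Path
from typing import List


def extract_all_archives(root: str) -> List[Path]:
    """Extract every .zip file found under `root` next to the archive itself.

    Returns the list of archives that were extracted, in sorted order.
    """
    root_path = Path(root)
    if not root_path.is_dir():
        # Nothing to do if the dataset folder does not exist yet.
        return []

    extracted = []
    for archive in sorted(root_path.rglob("*.zip")):
        with zipfile.ZipFile(archive) as zf:
            # Assumption: unpacking in place matches the layout the
            # training code expects; adjust the target if it does not.
            zf.extractall(archive.parent)
        extracted.append(archive)
    return extracted
```

Calling `extract_all_archives("./multigen20m")` after the download would unpack everything in one pass. Given the dataset size (more than 2 TB), check the available disk space first, since the archives are not deleted.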

Model Training (CUDA 11.0 and Conda 4.12.0 work)

Training from Scratch:

python train_unicontrol.py --ckpt ./ckpts/control_sd15_ini.ckpt --config ./models/cldm_v15_unicontrol_v11.yaml --lr 1e-5

Model Finetuning:

python train_unicontrol.py --ckpt ./ckpts/unicontrol.ckpt --config ./models/cldm_v15_unicontrol.yaml --lr 1e-7

Model Inference (CUDA 11.0 and Conda 4.12.0 work)

Run the command for the task you need, as listed below. If you hit an out-of-memory (OOM) error, decrease --num_samples.
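One way to automate the --num_samples advice is to retry with a smaller value after each failure. This is a rough sketch, not part of the repository: it assumes a failed run (for example a CUDA OOM) exits with a non-zero code, it cannot tell an OOM from any other error, and the helper name `run_with_fallback` and the starting value are hypothetical.

```python
import subprocess
from typing import Callable, List, Optional


def _run_subprocess(cmd: List[str]) -> int:
    """Run a command and return its exit code."""
    return subprocess.run(cmd).returncode


def run_with_fallback(
    base_cmd: List[str],
    num_samples: int,
    run: Callable[[List[str]], int] = _run_subprocess,
) -> Optional[int]:
    """Retry `base_cmd`, halving --num_samples after each failed run.

    Returns the sample count that succeeded, or None if even a single
    sample failed. Any non-zero exit code is treated as a failure, so a
    non-OOM error is also retried (an assumption of this sketch).
    """
    while num_samples >= 1:
        cmd = base_cmd + ["--num_samples", str(num_samples)]
        if run(cmd) == 0:
            return num_samples
        num_samples //= 2
    return None
```

For example, `run_with_fallback(["python", "inference_demo.py", "--ckpt", "./ckpts/unicontrol.ckpt", "--task", "canny"], num_samples=4)` would try 4, then 2, then 1 sample. The `run` parameter exists so the retry logic can be exercised without a GPU.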

If you use the safetensors model, load it by following ./load_model/load_safetensors_model.py.

Canny to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task canny

HED Edge to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task hed 

HED-like Sketch to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task hedsketch

Depth Map to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task depth 

Normal Surface Map to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task normal

Segmentation Map to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task seg

Human Skeleton to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task openpose

Object Bounding Boxes to Image Generation:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task bbox

Image Outpainting:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task outpainting

Image Inpainting:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task inpainting

Image Deblurring:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task blur

Image Colorization:

python inference_demo.py --ckpt ./ckpts/unicontrol.ckpt --task grayscale
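To run every task in one go, the commands above can be generated from the list of --task values. The sketch below only reuses the flags and task names shown in this section; the helper names `build_inference_command` and `run_all_tasks` are hypothetical.

```python
import subprocess
from typing import List

# Task names exactly as passed to --task in the commands above.
TASKS = [
    "canny", "hed", "hedsketch", "depth", "normal", "seg",
    "openpose", "bbox", "outpainting", "inpainting", "blur", "grayscale",
]


def build_inference_command(
    task: str, ckpt: str = "./ckpts/unicontrol.ckpt"
) -> List[str]:
    """Return the inference_demo.py command line for one task."""
    if task not in TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {TASKS}")
    return ["python", "inference_demo.py", "--ckpt", ckpt, "--task", task]


def run_all_tasks(ckpt: str = "./ckpts/unicontrol.ckpt") -> None:
    """Run the demo for each task in turn, stopping at the first failure."""
    for task in TASKS:
        subprocess.run(build_inference_command(task, ckpt), check=True)
```

Note that the --task value does not always match the task's display name: deblurring is `blur`, colorization is `grayscale`, and human skeleton is `openpose`.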

Gradio Demo (App Demo Video, CUDA 11.0 and Conda 4.12.0 work)

We provide Gradio demos for the different tasks. The example images are saved at ./test_imgs.

For all the tasks (Canny, HED, Sketch, Depth, Normal, Human Pose, Seg, Bbox, Outpainting, Colorization, Deblurring, Inpainting) please run the following code:

python app/gradio_all_tasks.py

We also support direct condition-to-image generation. Uncheck Condition Extraction in the UI if you want to upload a condition image directly.


Alternatively, we provide task-specific Gradio demos:

Canny to Image Generation:

python app/gradio_canny2image.py

HED Edge to Image Generation:

python app/gradio_hed2image.py

HED-like Sketch to Image Generation:

python app/gradio_hedsketch2image.py

Depth Map to Image Generation:

python app/gradio_depth2image.py

Normal Surface Map to Image Generation:

python app/gradio_normal2image.py

Segmentation Map to Image Generation:

python app/gradio_seg2image.py

Human Skeleton to Image Generation:

python app/gradio_pose2image.py

Object Bounding Boxes to Image Generation:

python app/gradio_bbox2image.py

Image Outpainting:

python app/gradio_outpainting.py

Image Colorization:

python app/gradio_colorization.py

Image Deblurring:

python app/gradio_deblur.py

Image Inpainting:

python app/gradio_inpainting.py
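The demo script names do not follow a single pattern (for example `gradio_pose2image.py` versus `gradio_deblur.py`), so a small lookup table can help when launching them from another script. The sketch below is illustrative only: the script paths come from the commands above, while the dictionary keys (reusing the --task names from the inference commands), the `"all"` key, and the function name `gradio_command` are assumptions made here.

```python
from typing import Dict, List

# Script paths as listed above; the keys are a convenience chosen for this sketch.
GRADIO_SCRIPTS: Dict[str, str] = {
    "all": "app/gradio_all_tasks.py",
    "canny": "app/gradio_canny2image.py",
    "hed": "app/gradio_hed2image.py",
    "hedsketch": "app/gradio_hedsketch2image.py",
    "depth": "app/gradio_depth2image.py",
    "normal": "app/gradio_normal2image.py",
    "seg": "app/gradio_seg2image.py",
    "openpose": "app/gradio_pose2image.py",
    "bbox": "app/gradio_bbox2image.py",
    "outpainting": "app/gradio_outpainting.py",
    "grayscale": "app/gradio_colorization.py",
    "blur": "app/gradio_deblur.py",
    "inpainting": "app/gradio_inpainting.py",
}


def gradio_command(task: str = "all") -> List[str]:
    """Return the command that launches the Gradio demo for `task`."""
    try:
        script = GRADIO_SCRIPTS[task]
    except KeyError:
        raise ValueError(
            f"unknown task {task!r}; expected one of {sorted(GRADIO_SCRIPTS)}"
        ) from None
    return ["python", script]
```

The returned list can be passed to `subprocess.run` from the repository root, which is where the relative `app/` paths resolve.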

To Do

Tips

Citation

If you find this project useful for your research, please kindly cite our paper:

@article{qin2023unicontrol,
  title={UniControl: A Unified Diffusion Model for Controllable Visual Generation In the Wild},
  author={Qin, Can and Zhang, Shu and Yu, Ning and Feng, Yihao and Yang, Xinyi and Zhou, Yingbo and Wang, Huan and Niebles, Juan Carlos and Xiong, Caiming and Savarese, Silvio and others},
  journal={arXiv preprint arXiv:2305.11147},
  year={2023}
}

Acknowledgement

This project is built upon the giant shoulders of ControlNet and Stable Diffusion. Great thanks to them!

Stable Diffusion https://github.com/CompVis/stable-diffusion

ControlNet https://github.com/lllyasviel/ControlNet

StyleGAN3 https://github.com/NVlabs/stylegan3