xyfJASON / ctrlora

Codebase for "CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation"
Apache License 2.0

banner

style-transfer

The images are compressed for loading speed.

CtrLoRA

CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation
Yifeng Xu (1,2), Zhenliang He (1), Shiguang Shan (1,2), Xilin Chen (1,2)
(1) Key Lab of AI Safety, Institute of Computing Technology, CAS, China
(2) University of Chinese Academy of Sciences, China

base-conditions

We first train a Base ControlNet together with condition-specific LoRAs on a set of base conditions, using a large-scale dataset. The Base ControlNet can then be efficiently adapted to novel conditions by training new LoRAs on as few as 1,000 images in less than 1 hour on a single GPU.
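To make the split between the shared Base ControlNet and the per-condition LoRAs concrete, here is a minimal NumPy sketch of the standard LoRA update W_eff = W + scale · B A applied to a single frozen weight matrix. It is illustrative only and not taken from this codebase: the layer size, the scale, and the condition names are assumptions, and the rank of 128 is borrowed from the name of the config file configs/ctrlora_finetune_sd15_rank128.yaml.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed layer size; rank 128 mirrors the "rank128" config name (illustrative only).
D_OUT, D_IN, RANK = 1280, 1280, 128

# Frozen weight shared by every condition (stands in for one Base ControlNet layer).
W_base = rng.standard_normal((D_OUT, D_IN)).astype(np.float32)


def new_lora(rank: int = RANK) -> dict:
    """Create a LoRA pair (A, B). B starts at zero, so the update is initially a no-op."""
    return {
        "A": rng.standard_normal((rank, D_IN)).astype(np.float32) * 0.01,
        "B": np.zeros((D_OUT, rank), dtype=np.float32),
    }


def effective_weight(lora: dict, scale: float = 1.0) -> np.ndarray:
    """Standard LoRA update: W_eff = W_base + scale * (B @ A); W_base itself is never modified."""
    return W_base + scale * (lora["B"] @ lora["A"])


# One LoRA per condition; only these small matrices would be trained for a new condition.
loras = {"canny": new_lora(), "my_custom_condition": new_lora()}

full_params = W_base.size
lora_params = loras["canny"]["A"].size + loras["canny"]["B"].size
print(f"full layer: {full_params:,} params, one LoRA: {lora_params:,} params "
      f"({lora_params / full_params:.0%})")
```

With these assumed sizes, one LoRA holds about 20% of the parameters of the layer it adapts, and because B starts at zero, a freshly created LoRA leaves the shared base weight unchanged until it is trained.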

🎨 Visual Results

🎨 Controllable generation on "base conditions"

base-conditions

🎨 Controllable generation on "novel conditions"

novel-conditions

🎨 Integration into community models & multi-conditional generation

integration

🎨 Application to style transfer

style-transfer

🛠️ Installation

Clone this repo:

git clone --depth 1 https://github.com/xyfJASON/ctrlora.git
cd ctrlora

Create and activate a new conda environment:

conda create -n ctrlora python=3.10
conda activate ctrlora

Install PyTorch and the other dependencies:

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt

🤖 Download Pretrained Models

We provide our pretrained models here. Please put the Base ControlNet (ctrlora_sd15_basecn700k.ckpt) into ./ckpts/ctrlora-basecn and the LoRAs into ./ckpts/ctrlora-loras. The naming convention of the LoRAs is ctrlora_sd15_<basecn>_<condition>.ckpt for base conditions and ctrlora_sd15_<basecn>_<condition>_<images>_<steps>.ckpt for novel conditions.
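A small helper can make this naming convention machine-readable, for example to list which LoRAs in ./ckpts/ctrlora-loras target base conditions and which target novel ones. This is a hypothetical sketch, not part of the repository: it assumes no field contains an underscore of its own (a name like lineart_anime would be misparsed or rejected), and it keeps <images> and <steps> as raw strings because their exact format is not specified here.

```python
import re
from pathlib import Path
from typing import Dict, Optional

# Assumption: "_" only ever separates fields, never appears inside one.
BASE_RE = re.compile(r"^ctrlora_sd15_(?P<basecn>[^_]+)_(?P<condition>[^_]+)\.ckpt$")
NOVEL_RE = re.compile(
    r"^ctrlora_sd15_(?P<basecn>[^_]+)_(?P<condition>[^_]+)"
    r"_(?P<images>[^_]+)_(?P<steps>[^_]+)\.ckpt$"
)


def parse_lora_name(filename: str) -> Optional[Dict[str, str]]:
    """Return the fields encoded in a LoRA filename, or None if it does not match."""
    for kind, pattern in (("novel", NOVEL_RE), ("base", BASE_RE)):
        match = pattern.match(filename)
        if match:
            return {"kind": kind, **match.groupdict()}
    return None


if __name__ == "__main__":
    # List every checkpoint in the LoRA folder with its parsed fields (None if unrecognized).
    for path in sorted(Path("ckpts/ctrlora-loras").glob("*.ckpt")):
        print(path.name, parse_lora_name(path.name))
```

Note that the Base ControlNet file itself (ctrlora_sd15_basecn700k.ckpt) does not match either pattern, so it is reported as None.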

You also need to download the SD1.5-based models and put them into ./ckpts/sd15. Models used in our work:

🚀 Gradio Demo

python app/gradio_ctrlora.py

Requires at least 9 GB of GPU memory to generate a batch of one 512x512 image, or 21 GB for a batch of four.

🚀 Single-conditional generation

  1. select the Stable Diffusion checkpoint, the Base ControlNet checkpoint, and the LoRA checkpoint.
  2. write prompts and negative prompts. We provide several commonly used prompts.
  3. prepare a condition image
    • upload an image to the left of the "Condition" panel, select the preprocessor corresponding to the LoRA, and click "Detect".
    • or upload the condition image directly, select the "none" preprocessor, and click "Detect".
  4. click "Run" to generate images.
  5. if you add any new checkpoints, restart Gradio or click "Refresh".
gradio

🚀 Multi-conditional generation

gradio2

🚀 Application to style transfer

  1. select a stylized Stable Diffusion checkpoint to specify the target style, e.g., Pixel.
  2. select the Base ControlNet checkpoint.
  3. select the palette LoRA as the LoRA1 checkpoint and the lineart LoRA as the LoRA2 checkpoint.
    • palette + canny and palette + hed also work; there may be more interesting combinations to discover.
  4. write prompts and negative prompts.
  5. upload the source image to the "Condition 1" panel, select the "none" preprocessor, and click "Detect".
  6. upload the source image to the "Condition 2" panel, select the "lineart" preprocessor, and click "Detect".
  7. adjust the weights for the two conditions in the "Basic options" panel.
  8. click "Run" to generate images.
gradio3

🔥 Train a LoRA for Your Custom Condition

Based on our Base ControlNet, you can train a LoRA for your custom condition with as few as 1,000 images in less than 1 hour on a single GPU (20 GB of memory).

First, download the Stable Diffusion v1.5 checkpoint (v1-5-pruned.ckpt) into ./ckpts/sd15 and the Base ControlNet (ctrlora_sd15_basecn700k.ckpt) into ./ckpts/ctrlora-basecn as described above.
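Before training, a quick sanity check can confirm that both files sit where the training command expects them. A minimal sketch, assuming it is run from the repository root; the helper name missing_checkpoints is hypothetical, and the paths are the ones passed to --sd_ckpt and --cn_ckpt in the training command.

```python
from pathlib import Path
from typing import List

# Paths used by the training command, relative to the repository root.
REQUIRED_CKPTS = [
    Path("ckpts/sd15/v1-5-pruned.ckpt"),
    Path("ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt"),
]


def missing_checkpoints(repo_root: str = ".") -> List[str]:
    """Return the required checkpoint paths that are not regular files under repo_root."""
    root = Path(repo_root)
    return [str(p) for p in REQUIRED_CKPTS if not (root / p).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoints:\n  " + "\n  ".join(missing))
    else:
        print("All required checkpoints found.")
```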

Second, put your custom data into ./data/<custom_data_name> with the following structure:

data
└── custom_data_name
    ├── prompt.json
    ├── source
    │   ├── 0000.jpg
    │   ├── 0001.jpg
    │   └── ...
    └── target
        ├── 0000.jpg
        ├── 0001.jpg
        └── ...
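A layout check can catch missing folders or unpaired images before a training run starts. This is a hypothetical sketch under stated assumptions: source and target images are assumed to be paired by identical filenames (as the tree suggests), the accepted extensions are a guess, and the contents of prompt.json are not validated because its format is not described here.

```python
import sys
from pathlib import Path
from typing import List, Set

# Assumption: accepted image extensions (the tree above only shows .jpg).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def _image_names(folder: Path) -> Set[str]:
    return {p.name for p in folder.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES}


def check_custom_dataset(dataroot: str) -> List[str]:
    """Return a list of layout problems; an empty list means the layout looks fine.

    Only the directory layout is checked; prompt.json is tested for existence only,
    and source/target images are assumed to be paired by identical filenames.
    """
    root = Path(dataroot)
    problems = []
    if not (root / "prompt.json").is_file():
        problems.append("missing prompt.json")
    names = {}
    for sub in ("source", "target"):
        folder = root / sub
        if folder.is_dir():
            names[sub] = _image_names(folder)
        else:
            problems.append(f"missing {sub}/ directory")
            names[sub] = set()
    # Report images that exist on one side only.
    for a, b in (("source", "target"), ("target", "source")):
        unmatched = sorted(names[a] - names[b])
        if unmatched:
            problems.append(
                f"{len(unmatched)} {a} image(s) without a {b} counterpart, e.g. {unmatched[0]}"
            )
    return problems


if __name__ == "__main__":
    dataroot = sys.argv[1] if len(sys.argv) > 1 else "./data/custom_data_name"
    problems = check_custom_dataset(dataroot)
    print("\n".join(problems) if problems else "no layout problems found")
```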

Third, run the following command to train the LoRA for your custom condition:

python scripts/train_ctrlora_finetune.py \
    --dataroot ./data/<custom_data_name> \
    --config ./configs/ctrlora_finetune_sd15_rank128.yaml \
    --sd_ckpt ./ckpts/sd15/v1-5-pruned.ckpt \
    --cn_ckpt ./ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt \
    [--name NAME] \
    [--max_steps MAX_STEPS]
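When training LoRAs for several custom conditions, it can help to assemble this command programmatically. The sketch below only reuses the flags shown above; the helper names and the condition folders my_condition_a and my_condition_b are hypothetical, and by default it prints the commands rather than launching them (call train(...) to actually run one from the repository root).

```python
import subprocess
import sys
from typing import List, Optional

CONFIG = "./configs/ctrlora_finetune_sd15_rank128.yaml"
SD_CKPT = "./ckpts/sd15/v1-5-pruned.ckpt"
CN_CKPT = "./ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt"


def build_train_cmd(dataroot: str, name: Optional[str] = None,
                    max_steps: Optional[int] = None) -> List[str]:
    """Assemble the fine-tuning command; the optional flags are added only if set."""
    cmd = [sys.executable, "scripts/train_ctrlora_finetune.py",
           "--dataroot", dataroot,
           "--config", CONFIG,
           "--sd_ckpt", SD_CKPT,
           "--cn_ckpt", CN_CKPT]
    if name is not None:
        cmd += ["--name", name]
    if max_steps is not None:
        cmd += ["--max_steps", str(max_steps)]
    return cmd


def train(dataroot: str, **kwargs) -> None:
    """Run training; raises subprocess.CalledProcessError if the script fails."""
    subprocess.run(build_train_cmd(dataroot, **kwargs), check=True)


if __name__ == "__main__":
    # Hypothetical condition folders under ./data; replace with your own.
    for condition in ("my_condition_a", "my_condition_b"):
        print(" ".join(build_train_cmd(f"./data/{condition}", name=condition)))
```

Passing an argument list (rather than a single shell string) to subprocess.run avoids quoting issues with paths that contain spaces.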

After training, extract the LoRA weights with the following command:

python scripts/tool_extract_weights.py -t lora --ckpt CHECKPOINT --save_path SAVE_PATH

Finally, put the extracted LoRA into ./ckpts/ctrlora-loras and use it in the Gradio demo.
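The last two steps can be combined by pointing --save_path straight at ./ckpts/ctrlora-loras. A minimal sketch, assuming it runs from the repository root: CHECKPOINT is whatever checkpoint file your training run produced (its location is not specified here), and the helper names and the example filename my_lora.ckpt are hypothetical. Afterwards, restart Gradio or click "Refresh" so the demo picks up the new file.

```python
import subprocess
import sys
from pathlib import Path
from typing import List

LORA_DIR = Path("ckpts/ctrlora-loras")  # folder the Gradio demo reads LoRAs from


def build_extract_cmd(checkpoint: str, save_path: Path) -> List[str]:
    """Assemble the weight-extraction command with the flags shown above."""
    return [sys.executable, "scripts/tool_extract_weights.py", "-t", "lora",
            "--ckpt", str(checkpoint), "--save_path", str(save_path)]


def extract_lora(checkpoint: str, lora_filename: str) -> Path:
    """Extract LoRA weights from a training checkpoint directly into LORA_DIR."""
    LORA_DIR.mkdir(parents=True, exist_ok=True)  # create the folder if it does not exist yet
    save_path = LORA_DIR / lora_filename
    subprocess.run(build_extract_cmd(checkpoint, save_path), check=True)
    return save_path


if __name__ == "__main__":
    # Hypothetical inputs: substitute your own training checkpoint and LoRA filename.
    print(" ".join(build_extract_cmd("path/to/CHECKPOINT.ckpt", LORA_DIR / "my_lora.ckpt")))
```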

📚 Detailed Instructions

Please refer to the instructions here for more details of training, fine-tuning, and evaluation.

🪧 Acknowledgement

This project is built upon Stable Diffusion, ControlNet, and UniControl. Thanks for their great work!

🖋️ Citation

If you find this project helpful, please consider citing:

@article{xu2024ctrlora,
  title={CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation},
  author={Xu, Yifeng and He, Zhenliang and Shan, Shiguang and Chen, Xilin},
  journal={arXiv preprint arXiv:2410.09400},
  year={2024}
}