CSGO: Content-Style Composition in Text-to-Image Generation 🔥
https://csgo-gen.github.io/

[**Peng Xing**](https://github.com/xingp-ng) 1,2* · [**Haofan Wang**](https://haofanwang.github.io/) 1* · [**Yanpeng Sun**](https://scholar.google.com.hk/citations?user=a3FI8c4AAAAJ&hl=zh-CN&oi=ao/) 2 · [**Qixun Wang**](https://github.com/wangqixun) 1 · [**Xu Bai**](https://huggingface.co/baymin0220) 1,3 · [**Hao Ai**](https://github.com/aihao2000) 1,4 · [**Renyuan Huang**](https://github.com/DannHuang) 1,5 · [**Zechao Li**](https://zechao-li.github.io/) 2✉

1 InstantX Team · 2 Nanjing University of Science and Technology · 3 Xiaohongshu · 4 Beihang University · 5 Peking University

\* equal contributions, ✉ corresponding author

[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/InstantX/CSGO) [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-App-red)](https://huggingface.co/spaces/xingpng/CSGO/) [![GitHub](https://img.shields.io/github/stars/instantX-research/CSGO?style=social)](https://github.com/instantX-research/CSGO)

Updates 🔥

Plan 💪

Introduction 📖

This repo, named CSGO, contains the official PyTorch implementation of our paper CSGO: Content-Style Composition in Text-to-Image Generation. We are actively updating and improving this repository. If you find bugs or have suggestions, you are welcome to raise issues or submit pull requests (PRs) 💖.

Pipeline 💻

Capabilities 🚅

🔥 Our CSGO achieves image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis.

🔥 For more results, visit our homepage 🔥

Getting Started 🏁

1. Clone the code and prepare the environment

git clone https://github.com/instantX-research/CSGO
cd CSGO

# create env using conda
conda create -n CSGO python=3.9
conda activate CSGO

# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
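
After installing, it can help to confirm that the environment is usable before downloading several gigabytes of weights. The following is a minimal sketch, not part of the repository: the module list is an assumption taken from the imports in the inference example of step 3, and `requirements.txt` may pull in more packages than these.

```python
import importlib.util
import sys

# Assumed module names, taken from the imports of the inference example;
# requirements.txt may list more packages than these.
REQUIRED_MODULES = ("torch", "diffusers", "PIL")


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_matches(major=3, minor=9):
    """Return True if the interpreter version matches the one used in `conda create`."""
    return sys.version_info[:2] == (major, minor)


if __name__ == "__main__":
    print("Python 3.9:", python_matches())
    print("Missing modules:", missing_modules() or "none")
```

Run it inside the activated `CSGO` environment; an empty list of missing modules means the imports used in step 3 should resolve.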

2. Download pretrained weights

We currently release two model weights.

| Model | Content tokens | Style tokens | Other |
|---|---|---|---|
| csgo.bin | 4 | 16 | - |
| csgo_4_32.bin | 4 | 32 | DeepSpeed ZeRO-2 |
| csgo_4_32_v2.bin | 4 | 32 | DeepSpeed ZeRO-2 + more (coming soon) |
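
The token counts in the table correspond to the `num_content_tokens` and `num_style_tokens` arguments of the `CSGO` constructor used in step 3, and they presumably have to match the checkpoint being loaded. A small lookup keeps the two in sync. This is a sketch: the values are copied from the table, while `CHECKPOINT_TOKENS` and `token_kwargs` are hypothetical names that do not exist in the repository.

```python
import os

# Token counts copied from the table above; the helper itself is hypothetical.
CHECKPOINT_TOKENS = {
    "csgo.bin": {"num_content_tokens": 4, "num_style_tokens": 16},
    "csgo_4_32.bin": {"num_content_tokens": 4, "num_style_tokens": 32},
    "csgo_4_32_v2.bin": {"num_content_tokens": 4, "num_style_tokens": 32},
}


def token_kwargs(ckpt_path):
    """Return the token-count keyword arguments for a checkpoint file name."""
    name = os.path.basename(ckpt_path)
    if name not in CHECKPOINT_TOKENS:
        raise ValueError(f"Unknown CSGO checkpoint: {name}")
    # Return a copy so callers cannot mutate the lookup table.
    return dict(CHECKPOINT_TOKENS[name])
```

The result can be unpacked into the constructor call, for example `CSGO(pipe, image_encoder_path, csgo_ckpt, device, **token_kwargs(csgo_ckpt), ...)`.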

The easiest way to download the pretrained weights is from HuggingFace:

# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone the weights repository (creates ./CSGO in the current directory)
git clone https://huggingface.co/InstantX/CSGO

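Our method is compatible with the standard SDXL base model, VAE, ControlNet, and image encoder. Download them and place them in the ./base_models folder. The inference example in step 3 expects the following subfolders: stable-diffusion-xl-base-1.0, sdxl-vae-fp16-fix, IP-Adapter/sdxl_models/image_encoder, and TTPLanet_SDXL_Controlnet_Tile_Realistic.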

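Tip: to load the ControlNet directly from its folder with ControlNetModel.from_pretrained, as the inference example in step 3 does, clone the repository and rename the weight file to diffusion_pytorch_model.safetensors (run these commands inside ./base_models so the path matches the example):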

git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors

3. Inference 🚀

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
)

from ip_adapter import CSGO
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS

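device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Local paths to the base models and the CSGO checkpoint
base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
# csgo_4_32.bin matches num_content_tokens=4 and num_style_tokens=32 used below;
# csgo.bin is listed with 16 style tokens in the weights table
csgo_ckpt = "./CSGO/csgo_4_32.bin"
pretrained_vae_name_or_path = "./base_models/sdxl-vae-fp16-fix"
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16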

# Load the VAE, ControlNet, and SDXL pipeline in half precision
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=weight_dtype)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    add_watermarker=False,
    vae=vae,
)
# Decode in tiles to reduce peak VAE memory
pipe.enable_vae_tiling()

# Names of the blocks targeted for content and style injection
target_content_blocks = BLOCKS['content']
target_style_blocks = BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']

# The token counts should match those listed for the checkpoint in the weights table
csgo = CSGO(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet=False,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
    load_controlnet=False,
)

style_name = 'img_0.png'
content_name = 'img_0.png'
# Both inputs are passed to generate() as PIL images, so open the style image too
style_image = Image.open('../assets/{}'.format(style_name)).convert('RGB')
content_image = Image.open('../assets/{}'.format(content_name)).convert('RGB')

caption = 'a small house with a sheep statue on top of it'

num_sample = 4

# Image-driven style transfer
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image.convert('RGB'),
    controlnet_conditioning_scale=0.6,
)

# Text-driven stylized synthesis (ControlNet conditioning scale close to 0)
caption = 'a cat'
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image.convert('RGB'),
    controlnet_conditioning_scale=0.01,
)

# Text editing-driven stylized synthesis
caption = 'a small house'
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image.convert('RGB'),
    controlnet_conditioning_scale=0.4,
)
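
The three calls above differ only in the prompt and in `controlnet_conditioning_scale` (0.6, 0.01, and 0.4), which suggests that this scale sets how strongly the content image constrains the result. The sketch below collects those per-mode values in one place. The values are copied from the example; `MODE_PRESETS`, `NEGATIVE_PROMPT`, and `build_generate_kwargs` are hypothetical names introduced for illustration and are not part of the repository.

```python
# Values copied from the three calls above; the preset and helper names are hypothetical.
NEGATIVE_PROMPT = (
    "text, watermark, lowres, low quality, worst quality, deformed, glitch, "
    "low contrast, noisy, saturation, blurry"
)

MODE_PRESETS = {
    "image_style_transfer": {
        "prompt": "a small house with a sheep statue on top of it",
        "controlnet_conditioning_scale": 0.6,
    },
    "text_stylized_synthesis": {
        "prompt": "a cat",
        "controlnet_conditioning_scale": 0.01,
    },
    "text_editing_stylized_synthesis": {
        "prompt": "a small house",
        "controlnet_conditioning_scale": 0.4,
    },
}


def build_generate_kwargs(mode, content_image, style_image, num_sample=4, seed=42):
    """Assemble the keyword arguments used in the example for one of the three modes."""
    preset = MODE_PRESETS[mode]
    return dict(
        pil_content_image=content_image,
        pil_style_image=style_image,
        prompt=preset["prompt"],
        negative_prompt=NEGATIVE_PROMPT,
        content_scale=1.0,
        style_scale=1.0,
        guidance_scale=10,
        num_images_per_prompt=num_sample,
        num_samples=1,
        num_inference_steps=50,
        seed=seed,
        image=content_image,
        controlnet_conditioning_scale=preset["controlnet_conditioning_scale"],
    )
```

With the objects from the example in scope, a call would look like `images = csgo.generate(**build_generate_kwargs("text_stylized_synthesis", content_image, style_image))`.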

4. Gradio interface ⚙️

We also provide a Gradio interface for a better experience. Launch it with:

# For Linux, Windows, and macOS users
python gradio/app.py

If you lack the resources to run it locally, try our online demo on Hugging Face: https://huggingface.co/spaces/xingpng/CSGO/

Demos


Content-Style Composition

Cycle Translation

Text-Driven Style Synthesis

Text Editing-Driven Style Synthesis

Acknowledgements

This project is developed by the InstantX Team and Xiaohongshu; all rights reserved. Sincere thanks to Xiaohongshu for providing the computing resources.

Citation 💖

If you find CSGO useful for your research, please 🌟 this repo and cite our work using the following BibTeX:

@article{xing2024csgo,
       title={CSGO: Content-Style Composition in Text-to-Image Generation}, 
       author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
       year={2024},
       journal = {arXiv 2408.16766},
}