This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our paper exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on our project page.
PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation
Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li
Huawei Noah’s Ark Lab, DLUT, HKU, HKUST
Learning from the previous PixArt-α project, we will try to keep this repo as simple as possible so that everyone in the PixArt community can use it.
Model | T5 token length | VAE | 2K/4K |
---|---|---|---|
PixArt-Σ | 300 | SDXL | ✅ |
PixArt-α | 120 | SD1.5 | ❌ |
Model | Sample-1 | Sample-2 | Sample-3 |
---|---|---|---|
PixArt-Σ | |||
PixArt-α | |||
Prompt | Close-up, gray-haired, bearded man in 60s, observing passersby, in wool coat and brown beret, glasses, cinematic. | Body shot, a French woman, Photography, French Streets background, backlight, rim light, Fujifilm. | Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. |
conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt
First of all, we started a new repo to build a more user-friendly and more compatible codebase. The main model structure is the same as PixArt-α, so you can still develop your functions based on the original repo. Also, this repo will support PixArt-alpha in the future.
> [!TIP]
> Now you can train your model without prior feature extraction. We reworked the data structure of the PixArt-α code base so that everyone can start training, inference, and visualization right away, without any pain.
Download the toy dataset first. The dataset structure for training is:
cd ./pixart-sigma-toy-dataset
Dataset Structure
├──InternImgs/ (images are saved here)
│ ├──000000000000.png
│ ├──000000000001.png
│ ├──......
├──InternData/
│ ├──data_info.json (meta data)
Optional(👇)
│ ├──img_sdxl_vae_features_1024resolution_ms_new (pre-extracted SDXL-VAE image features, same name as images except .npy extension)
│ │ ├──000000000000.npy
│ │ ├──000000000001.npy
│ │ ├──......
│ ├──caption_features_new
│ │ ├──000000000000.npz
│ │ ├──000000000001.npz
│ │ ├──......
│ ├──sharegpt4v_caption_features_new (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│ │ ├──000000000000.npz
│ │ ├──000000000001.npz
│ │ ├──......
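Before relying on the optional pre-extracted features, it can help to check that every image has a matching feature file. The following is a minimal sketch, assuming only the directory layout shown above; the function name `find_missing_caption_features` is hypothetical and not part of this repo.

```python
from pathlib import Path


def find_missing_caption_features(dataset_root, feature_dir="caption_features_new"):
    """Return the stems of images in InternImgs/ that have no matching .npz file.

    Assumes the layout from the tree above: images live in
    <root>/InternImgs/*.png and caption features in
    <root>/InternData/<feature_dir>/<same stem>.npz.
    """
    root = Path(dataset_root)
    feature_root = root / "InternData" / feature_dir
    # Sort so the result is deterministic regardless of filesystem order.
    image_stems = sorted(p.stem for p in (root / "InternImgs").glob("*.png"))
    return [s for s in image_stems if not (feature_root / f"{s}.npz").is_file()]
```

Calling `find_missing_caption_features("./pixart-sigma-toy-dataset")` returns an empty list when every image has a caption feature file. Passing `feature_dir="sharegpt4v_caption_features_new"` checks the other caption folder in the same way.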
# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers
# PixArt-Sigma checkpoints
python tools/download.py # optionally set an environment variable, e.g. HF_ENDPOINT=https://hf-mirror.com, to use a Hugging Face mirror
Select your desired config file from the config files directory.
python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 \
train_scripts/train.py \
configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py \
--load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth \
--work-dir output/your_first_pixart-exp \
--debug
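If you launch several experiments from a script, the command above can be assembled programmatically. This is a minimal sketch that only rebuilds the flags shown in the command above; the helper name `build_train_command` is hypothetical and not part of this repo.

```python
import shlex


def build_train_command(config, load_from, work_dir,
                        nproc_per_node=1, master_port=12345, debug=False):
    """Build the argument list for the training launch command shown above."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
        "train_scripts/train.py",
        config,
        "--load-from", load_from,
        "--work-dir", work_dir,
    ]
    if debug:
        cmd.append("--debug")
    return cmd


cmd = build_train_command(
    config="configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py",
    load_from="output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth",
    work_dir="output/your_first_pixart-exp",
    debug=True,
)
# Print a copy-pasteable shell command for inspection.
print(shlex.join(cmd))
```

To actually start training, pass the list to `subprocess.run(cmd, check=True)` from the repo root. Keeping the command as a list avoids shell-quoting problems when paths contain spaces.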
To get started, first install the required dependencies. Make sure you've downloaded the checkpoint files from models (coming soon) to the `output/pretrained_models` folder, and then run on your local machine:
# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers
# PixArt-Sigma checkpoints
python tools/download.py
# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223
> [!IMPORTANT]
> Upgrade your `diffusers` to make the `PixArtSigmaPipeline` available!

```bash
pip install git+https://github.com/huggingface/diffusers
```

For `diffusers<0.28.0`, check this script for help.

```python
import torch
from diffusers import Transformer2DModel, PixArtSigmaPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16

# Load the PixArt-Sigma transformer weights.
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder='transformer',
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
# Build the pipeline from the SDXL-VAE + T5 repo and plug in the transformer.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe.to(device)

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("./cactus.png")
```
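For reproducible samples you may want to control the seed, step count, and guidance scale instead of relying on the pipeline defaults. The sketch below assumes a `pipe` object created as in the snippet above; the helper names `build_generation_kwargs` and `generate` are hypothetical, and the default values are illustrative rather than recommended settings.

```python
def build_generation_kwargs(prompt, negative_prompt="", steps=20,
                            guidance_scale=4.5, height=1024, width=1024):
    """Collect keyword arguments for a PixArtSigmaPipeline call."""
    # The diffusers PixArt pipelines reject sizes that are not multiples of 8,
    # so fail early with a clear message.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError("height and width must be divisible by 8")
    return {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
    }


def generate(pipe, seed=0, **kwargs):
    """Run the pipeline with a fixed seed and return the first image."""
    import torch  # imported lazily so the kwargs helper works without torch

    generator = torch.Generator(device="cpu").manual_seed(seed)
    return pipe(generator=generator, **kwargs).images[0]
```

A call such as `generate(pipe, seed=42, **build_generation_kwargs("A small cactus with a happy face in the Sahara desert."))` should return the same image on repeated runs with the same hardware and library versions.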
## 3. PixArt Demo
```bash
pip install git+https://github.com/huggingface/diffusers
# PixArt-Sigma 1024px
DEMO_PORT=12345 python app/app_pixart_sigma.py
# PixArt-Sigma One step Sampler(DMD)
DEMO_PORT=12345 python app/app_pixart_dmd.py
```
Let's have a look at a simple example at http://your-server-ip:12345.
Directly download the diffusers-format weights from Hugging Face, or convert a .pth checkpoint with:
pip install git+https://github.com/huggingface/diffusers
python tools/convert_pixart_to_diffusers.py --orig_ckpt_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth --dump_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS --only_transformer=True --image_size=1024 --version sigma
All models will be automatically downloaded here. You can also choose to download manually from this url.
Model | #Params | Checkpoint path | Download in OpenXLab |
---|---|---|---|
T5 & SDXL-VAE | 4.5B | Diffusers: pixart_sigma_sdxlvae_T5_diffusers | coming soon |
PixArt-Σ-256 | 0.6B | pth: PixArt-Sigma-XL-2-256x256.pth; Diffusers: PixArt-Sigma-XL-2-256x256 | coming soon |
PixArt-Σ-512 | 0.6B | pth: PixArt-Sigma-XL-2-512-MS.pth; Diffusers: PixArt-Sigma-XL-2-512-MS | coming soon |
PixArt-α-512-DMD | 0.6B | Diffusers: PixArt-Alpha-DMD-XL-2-512x512 | coming soon |
PixArt-Σ-1024 | 0.6B | pth: PixArt-Sigma-XL-2-1024-MS.pth; Diffusers: PixArt-Sigma-XL-2-1024-MS | coming soon |
PixArt-Σ-2K | 0.6B | pth: PixArt-Sigma-XL-2-2K-MS.pth; Diffusers: PixArt-Sigma-XL-2-2K-MS | coming soon |
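When scripting training or inference, a small lookup from resolution to checkpoint file keeps paths consistent. This sketch uses only the `.pth` file names listed in the table above and the `output/pretrained_models` folder used elsewhere in this README; the names `SIGMA_CHECKPOINTS` and `checkpoint_path` are hypothetical.

```python
from pathlib import Path

# .pth file names copied from the model table above.
SIGMA_CHECKPOINTS = {
    "256": "PixArt-Sigma-XL-2-256x256.pth",
    "512": "PixArt-Sigma-XL-2-512-MS.pth",
    "1024": "PixArt-Sigma-XL-2-1024-MS.pth",
    "2K": "PixArt-Sigma-XL-2-2K-MS.pth",
}


def checkpoint_path(resolution, root="output/pretrained_models"):
    """Map a resolution such as 512, "1024" or "2k" to its local .pth path."""
    key = str(resolution).upper()
    if key not in SIGMA_CHECKPOINTS:
        known = ", ".join(SIGMA_CHECKPOINTS)
        raise KeyError(f"unknown resolution {resolution!r}; expected one of: {known}")
    return Path(root) / SIGMA_CHECKPOINTS[key]
```

For example, `checkpoint_path(512)` points at the file passed to `--load-from` and `--model_path` in the commands earlier in this README.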
We will try our best to release
@misc{chen2024pixartsigma,
title={PixArt-\Sigma: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation},
author={Junsong Chen and Chongjian Ge and Enze Xie and Yue Wu and Lewei Yao and Xiaozhe Ren and Zhongdao Wang and Ping Luo and Huchuan Lu and Zhenguo Li},
year={2024},
eprint={2403.04692},
archivePrefix={arXiv},
primaryClass={cs.CV}