huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

The main branch is broken for ControlNet training with SDXL #4206

Closed: yutongli closed this issue 1 year ago

yutongli commented 1 year ago

Describe the bug

The main branch (including all PRs up to #4205) fails ControlNet training with SDXL. The failure looks like a misalignment of arguments.

Reproduction

export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export VAE_DIR="madebyollin/sdxl-vae-fp16-fix"
export OUTPUT_DIR="product_train_output_extract_1stbatch_100k_sdxl0.9_1024_lr1"
export CACHE_DIR="/home/ubuntu/lamda-filesystem/custom_cache"

accelerate launch --mixed_precision="fp16" train_controlnet_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
  --pretrained_vae_model_name_or_path=$VAE_DIR \
  --cache_dir=$CACHE_DIR \
  --dataset_name=all_training_full_extract \
  --image_column="target" \
  --conditioning_image_column="source" \
  --caption_column="prompt" \
  --resolution=1024 \
  --learning_rate=1e-5 \
  --validation_image "./val1_extract_source.jpg" "./val2_extract_source.jpg" "./val3_extract_source.jpg" "./popchange.png" \
  --validation_prompt "a white trash can sitting on a table next to a plant" "a bottle of liquid with flower in it" "a rack with a bunch of shoes on it" "a doll in galaxy" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=12 \
  --tracker_project_name="product_train_output_extract_1stbatch_100k_sdxl0.9_1024_lr1" \
  --num_train_epochs=20 \
  --report_to=wandb \
  --validation_steps=100 \
  --checkpointing_steps=1000 \
  --checkpoints_total_limit=10 \
  --seed=42 \
  --enable_xformers_memory_efficient_attention

Logs

Traceback (most recent call last):
  File "train_controlnet_sdxl.py", line 1248, in <module>
    main(args)
  File "train_controlnet_sdxl.py", line 1212, in main
    image_logs = log_validation(
  File "train_controlnet_sdxl.py", line 126, in log_validation
    image = pipeline(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/lamda-filesystem/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py", line 782, in __call__
    self.check_inputs(
  File "/home/ubuntu/lamda-filesystem/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py", line 450, in check_inputs
    raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
ValueError: `prompt_2` has to be of type `str` or `list` but is <class 'PIL.Image.Image'>
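The error says `prompt_2` received a `PIL.Image.Image`, which suggests the validation image is being passed positionally and now lands in a newly added `prompt_2` slot instead of `image`. The dependency-free sketch below illustrates that failure mode under stated assumptions: `ToyPipelineOld`, `ToyPipelineNew` and `FakeImage` are hypothetical stand-ins, and the parameter order `prompt, prompt_2, image` is inferred from the traceback rather than copied from diffusers. It also shows why passing `prompt=` and `image=` by keyword is robust to such signature changes. The thread does not show whether the later fix took exactly this approach.

```python
# Hypothetical toy classes, NOT the real diffusers pipeline. The parameter
# order (prompt, prompt_2, image) is an assumption inferred from the traceback.


class ToyPipelineOld:
    """Signature before a second prompt argument existed."""

    def __call__(self, prompt=None, image=None):
        return {"prompt": prompt, "image": image}


class ToyPipelineNew:
    """Signature after `prompt_2` was inserted before `image`."""

    def __call__(self, prompt=None, prompt_2=None, image=None):
        # Rough analogue of the type check that raised in check_inputs.
        if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
            raise ValueError(
                f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
            )
        return {"prompt": prompt, "prompt_2": prompt_2, "image": image}


class FakeImage:
    """Placeholder for PIL.Image.Image so the sketch needs no dependencies."""


validation_prompt = "a doll in galaxy"
validation_image = FakeImage()

# A positional call binds correctly against the old signature...
old_result = ToyPipelineOld()(validation_prompt, validation_image)

# ...but against the new signature the image shifts into `prompt_2`.
try:
    ToyPipelineNew()(validation_prompt, validation_image)
    positional_error = None
except ValueError as err:
    positional_error = str(err)

# Keyword arguments keep binding to the intended parameters.
new_result = ToyPipelineNew()(prompt=validation_prompt, image=validation_image)

print(positional_error)
```

Applied to the training script, the equivalent change would be calling the pipeline in `log_validation` with `prompt=...` and `image=...` instead of relying on argument order.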

System Info

Who can help?

@patrickvonplaten @sayakpaul

sayakpaul commented 1 year ago

Cc @patrickvonplaten

Maybe the recent introduction of `prompt_2` needs a closer look. Will look into this. In the meantime, could you try the version from https://github.com/huggingface/diffusers/pull/4038?

sayakpaul commented 1 year ago

https://github.com/huggingface/diffusers/pull/4223 should have fixed this issue.