
`SDXLCFGCutoffCallback` does not work with `StableDiffusionXLControlNetPipeline` #8686

Open rootonchair opened 2 months ago

rootonchair commented 2 months ago

Describe the bug

Running `SDXLCFGCutoffCallback` with the SDXL ControlNet pipeline raises the following error:

diffusers/src/diffusers/models/attention.py:372, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    364         norm_hidden_states = self.pos_embed(norm_hidden_states)
    366     attn_output = self.attn2(
    367         norm_hidden_states,
    368         encoder_hidden_states=encoder_hidden_states,
    369         attention_mask=encoder_attention_mask,
    370         **cross_attention_kwargs,
    371     )
--> 372     hidden_states = attn_output + hidden_states
    374 # 4. Feed-forward
    375 # i2vgen doesn't have this norm 🤷‍♂️
    376 if self.norm_type == "ada_norm_continuous":

RuntimeError: The size of tensor a (8192) must match the size of tensor b (4096) at non-singleton dimension 1

This occurs because the conditioning image (https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1488) is not converted back to batch size 1 when the callback cuts off classifier-free guidance.

So the solution would be either adding a new callback for ControlNet or fixing the current callback so it also converts the conditioning image back to batch size 1.
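
In the meantime, one possible workaround is a subclass of the callback that additionally truncates the conditioning image at the cutoff step. The sketch below is untested and assumes the SDXL ControlNet pipeline exposes the conditioning image to callbacks under an `"image"` tensor input (i.e. that `"image"` is accepted in its `_callback_tensor_inputs` and read back from the callback outputs); if it is not, the slicing would have to happen inside a custom pipeline instead. The `SDXLControlNetCFGCutoffCallback` name is purely illustrative.

from diffusers.callbacks import SDXLCFGCutoffCallback


class SDXLControlNetCFGCutoffCallback(SDXLCFGCutoffCallback):
    # Also request the ControlNet conditioning image from the pipeline
    # (assumes the pipeline accepts "image" as a callback tensor input -- not verified).
    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids", "image"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        # Let the stock callback truncate the text embeddings and zero out the guidance scale first
        callback_kwargs = super().callback_fn(pipeline, step_index, timestep, callback_kwargs)

        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            # Keep only the conditional copy of the duplicated conditioning image so its
            # batch size matches the truncated prompt embeddings (single-prompt batch assumed)
            image = callback_kwargs[self.tensor_inputs[-1]]
            callback_kwargs[self.tensor_inputs[-1]] = image[-1:]

        return callback_kwargs

It would then be passed in place of the stock callback, e.g. callback = SDXLControlNetCFGCutoffCallback(cutoff_step_ratio=0.4) in the reproduction script below.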

Reproduction

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
from diffusers.callbacks import SDXLCFGCutoffCallback
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import cv2
import numpy as np
import torch

original_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)

# Build a Canny edge map to use as the ControlNet conditioning image
image = np.array(original_image)

low_threshold = 100
high_threshold = 200

image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

# Load the Canny SDXL ControlNet and the fp16-fix VAE
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True
)
pipe.enable_model_cpu_offload()

prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'
# Cut off classifier-free guidance after 40% of the denoising steps
callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
image = pipe(
    prompt,
    negative_prompt=negative_prompt,
    image=canny_image,
    controlnet_conditioning_scale=0.5,
    callback_on_step_end=callback,
).images[0]
make_image_grid([original_image, canny_image, image], rows=1, cols=3)

Logs

No response

System Info

Who can help?

@sayakpaul @yiyixuxu

asomoza commented 2 months ago

Thanks for reporting this issue, I'll look into it. The best option here is probably to make the same callback work with all the SDXL-related pipelines.

juancopi81 commented 2 weeks ago

Hi there! I just ran into this problem. Any update here? Or is there code we could use to override the callback in the meantime? Thanks a lot.

asomoza commented 2 weeks ago

By coincidence, I'm just working on this. I'll open a PR soon.