ai-forever / Kandinsky-2

Kandinsky 2 — multilingual text2image latent diffusion model
Apache License 2.0

Bug: ControlNet image generation fails when guidance_scale <= 1 #85

Open chuck-ma opened 1 year ago

chuck-ma commented 1 year ago

Running the ControlNet pipeline with `guidance_scale <= 1` fails with: `Sizes of tensors must match except in dimension 1. Expected size 1 but got size 2 for tensor number 1 in the list.`

Reading the code, I found that with `guidance_scale <= 1` the batch dimension (`shape[0]`) of `sample` is 1, while the batch dimension of `hint` is 2.
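Classifier-free guidance conventionally doubles the latent batch (one unconditional and one conditional copy) only when `guidance_scale > 1`. The reported shapes are consistent with the hint being doubled regardless of the guidance scale. The sketch below models that; `prepare_batches` and its `duplicate_hint_always` flag are hypothetical names, and the model is an inference from the reported shapes, not a copy of the pipeline source.

```python
import torch


def prepare_batches(latents, hint, guidance_scale, duplicate_hint_always):
    """Hypothetical model of how a CFG pipeline batches its UNet inputs."""
    do_cfg = guidance_scale > 1.0
    # Latents are doubled (unconditional + conditional) only when CFG is on.
    latent_input = torch.cat([latents] * 2) if do_cfg else latents
    # Suspected bug (assumption): the hint is doubled even when CFG is off.
    if duplicate_hint_always or do_cfg:
        hint = torch.cat([hint, hint], dim=0)
    return latent_input, hint


latents = torch.zeros(1, 4, 96, 96)  # illustrative latent shape
hint = torch.zeros(1, 3, 768, 768)   # one depth hint, as built in the repro

for scale in (1.0, 4.0):
    for always in (True, False):
        sample, h = prepare_batches(latents, hint, scale, always)
        print(f"guidance_scale={scale}, duplicate_hint_always={always}: "
              f"sample batch={sample.shape[0]}, hint batch={h.shape[0]}")
```

Only the combination `guidance_scale=1.0, duplicate_hint_always=True` produces mismatched batches (1 vs. 2), matching the report; tying the hint duplication to the same `guidance_scale > 1` check keeps both batches aligned.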


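If so, concatenating `sample` and `hint` along `dim=1` with `torch.cat` is bound to fail: `torch.cat` requires every dimension except the concatenation dimension to match, and here dimension 0 (the batch) is 1 for one tensor and 2 for the other.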
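The failure can be reproduced in isolation with plain tensors. The shapes below are illustrative assumptions (4 channels at a 96×96 latent resolution for both tensors); only the batch sizes, 1 and 2, come from the report, and `concat_hint` is a hypothetical helper standing in for the channel-wise concatenation.

```python
import torch


def concat_hint(sample, hint):
    # Join the latent sample and the hint along the channel dimension (dim=1);
    # every other dimension, including the batch, must match.
    return torch.cat([sample, hint], dim=1)


# Illustrative shapes: batch 1 for the sample, batch 2 for the hint.
sample = torch.zeros(1, 4, 96, 96)
hint = torch.zeros(2, 4, 96, 96)

try:
    concat_hint(sample, hint)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

# With matching batch sizes the same call succeeds and only dim 1 grows.
print(concat_hint(sample, hint[:1]).shape)  # torch.Size([1, 8, 96, 96])
```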

The following code reproduces the error:

```python
import torch
import numpy as np
from transformers import pipeline
from diffusers.utils import load_image
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline


def make_hint(image, depth_estimator):
    # Depth map -> 3-channel float tensor in [0, 1], shaped (C, H, W)
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    detected_map = torch.from_numpy(image).float() / 255.0
    hint = detected_map.permute(2, 0, 1)
    return hint


img = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))

depth_estimator = pipeline("depth-estimation")
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")  # batch size 1

pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior = pipe_prior.to("cuda")

pipe = KandinskyV22ControlnetPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "A robot, 4k photo"
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"

generator = torch.Generator(device="cuda").manual_seed(43)
image_emb, zero_image_emb = pipe_prior(prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator).to_tuple()

images = pipe(
    image_embeds=image_emb, negative_image_embeds=zero_image_emb, hint=hint,
    num_inference_steps=50, generator=generator, height=768, width=768,
    guidance_scale=1.0,  # any value <= 1 triggers the error
).images
images[0].save("robot_cat.png")
```
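The hint passed in by the caller has batch size 1, which suggests the second copy is added inside the pipeline. That can be checked offline, without downloading the depth model or using a GPU, by running the same `make_hint` steps against a stand-in estimator. `fake_depth_estimator` is hypothetical: it only mimics the `"depth"` key (a single-channel PIL image) that the repro reads from the `transformers` depth-estimation pipeline.

```python
import numpy as np
import torch
from PIL import Image


def make_hint(image, depth_estimator):
    # Same steps as in the repro: grayscale depth -> 3 channels -> (C, H, W) in [0, 1].
    image = depth_estimator(image)["depth"]
    image = np.array(image)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    detected_map = torch.from_numpy(image).float() / 255.0
    return detected_map.permute(2, 0, 1)


def fake_depth_estimator(image):
    # Hypothetical stand-in: returns a constant single-channel ("L") depth image.
    return {"depth": Image.new("L", image.size, color=128)}


img = Image.new("RGB", (768, 768))
hint = make_hint(img, fake_depth_estimator).unsqueeze(0)
print(hint.shape)  # torch.Size([1, 3, 768, 768])
```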