huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Mismatching size in matmul when using StableDiffusionInstructPix2Pix pipeline with IP-Adapter #7799

Closed. misshimichka closed this issue 4 months ago.

misshimichka commented 4 months ago

Describe the bug

I tried to combine the InstructPix2Pix model with IP-Adapter (`pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")`), but I always get `RuntimeError: mat1 and mat2 shapes cannot be multiplied (771x1280 and 1024x3072)`. With other models, e.g. Stable Diffusion or SD-XL, everything works. I think it's because InstructPix2Pix and IP-Adapter have different dimensions.

Reproduction

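import torch
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(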
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float32)

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float32,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None
)
# Attach the SD 1.5 IP-Adapter weights to the InstructPix2Pix pipeline
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://huggingface.co/datasets/huggingface/documentation/images/resolve/main/diffusers/load_neg_embed.png")

generator = torch.Generator(device="cpu").manual_seed(33)
images = pipe(
    prompt='best quality, high quality',
    image=image,
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator
).images[0]

Logs

RuntimeError                              Traceback (most recent call last)
Cell In[32], line 2
      1 generator = torch.Generator(device="cpu").manual_seed(33)
----> 2 images = pipe(
      3     prompt='best quality, high quality',
      4     image=image,
      5     ip_adapter_image=image,
      6     negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      7     num_inference_steps=50,
      8     generator=generator
      9 ).images[0]
     10 images

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py:404, in StableDiffusionInstructPix2PixPipeline.__call__(self, prompt, image, num_inference_steps, guidance_scale, image_guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, ip_adapter_image, output_type, return_dict, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
    401 scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
    403 # predict the noise residual
--> 404 noise_pred = self.unet(
    405     scaled_latent_model_input,
    406     t,
    407     encoder_hidden_states=prompt_embeds,
    408     added_cond_kwargs=added_cond_kwargs,
    409     return_dict=False,
    410 )[0]
    412 # perform guidance
    413 if self.do_classifier_free_guidance:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1164, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1161 if self.time_embed_act is not None:
   1162     emb = self.time_embed_act(emb)
-> 1164 encoder_hidden_states = self.process_encoder_hidden_states(
   1165     encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
   1166 )
   1168 # 2. pre-process
   1169 sample = self.conv_in(sample)

File /opt/conda/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1035, in UNet2DConditionModel.process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs)
   1031         raise ValueError(
   1032             f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
   1033         )
   1034     image_embeds = added_cond_kwargs.get("image_embeds")
-> 1035     image_embeds = self.encoder_hid_proj(image_embeds)
   1036     encoder_hidden_states = (encoder_hidden_states, image_embeds)
   1037 return encoder_hidden_states

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/embeddings.py:909, in MultiIPAdapterImageProjection.forward(self, image_embeds)
    907 batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
    908 image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
--> 909 image_embed = image_projection_layer(image_embed)
    910 image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
    912 projected_image_embeds.append(image_embed)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/embeddings.py:458, in ImageProjection.forward(self, image_embeds)
    455 batch_size = image_embeds.shape[0]
    457 # image
--> 458 image_embeds = self.image_embeds(image_embeds)
    459 image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
    460 image_embeds = self.norm(image_embeds)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (771x1280 and 1024x3072)
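The numbers in this error can be decoded with some shape arithmetic. A linear layer flattens all leading dimensions of its input into the rows of `mat1` and requires the last dimension to equal `in_features`. The pure-Python sketch below mimics only that shape rule; it is not PyTorch's implementation. The shapes fed to it are an inference, not something confirmed in this thread. It assumes the failing layer is `Linear(1024, 4 * 768)`, which projects a 1024-d CLIP image embedding to 4 tokens of SD 1.5's 768-d cross-attention width. It also assumes the layer received CLIP ViT-H penultimate hidden states of shape `(3, 257, 1280)` instead of pooled embeddings of shape `(3, 1024)`. Here 3 is the number of InstructPix2Pix guidance branches, and 257 = 16 × 16 patches + 1 class token. The helper name `linear_output_shape` is hypothetical.

```python
from math import prod


def linear_output_shape(input_shape, in_features, out_features):
    """Hypothetical helper mimicking a linear layer's shape check (y = x @ W.T + b).

    All leading dimensions are flattened into the rows of mat1, and the last
    dimension must equal in_features. Only shapes are modeled, not values.
    """
    rows, cols = prod(input_shape[:-1]), input_shape[-1]
    if cols != in_features:
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    return tuple(input_shape[:-1]) + (out_features,)


# Assumption: 1024-d image embedding -> 4 tokens x 768 (SD 1.5 cross-attention width).
IN_FEATURES, OUT_FEATURES = 1024, 4 * 768

# Assumption: CLIP ViT-H penultimate hidden states (257 tokens x 1280) for the
# 3 guidance branches reached the layer, instead of pooled 1024-d embeddings.
hidden_states_shape = (3, 257, 1280)
pooled_embeds_shape = (3, 1024)

try:
    linear_output_shape(hidden_states_shape, IN_FEATURES, OUT_FEATURES)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (771x1280 and 1024x3072)

print(linear_output_shape(pooled_embeds_shape, IN_FEATURES, OUT_FEATURES))  # (3, 3072)
```

Under these assumptions, the 1280 vs. 1024 mismatch points to which CLIP output reaches the projection layer, rather than to the IP-Adapter weights themselves.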

System Info

- `diffusers` version: 0.27.2
- Platform: Linux-5.15.133+-x86_64-with-glibc2.31
- Python version: 3.10.13
- PyTorch version (GPU?): 2.1.2 (True)
- Huggingface_hub version: 0.22.2
- Transformers version: 4.39.3
- Accelerate version: 0.29.3
- xFormers version: not installed
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No

yiyixuxu commented 4 months ago

Thanks! I can reproduce this; looking into it now!

yiyixuxu commented 4 months ago

Hey, I fixed this in #7820. Let me know if it works for you!

misshimichka commented 4 months ago

@yiyixuxu Thank you so much for the quick answer! It works for me too. Looking forward to seeing the updated pipeline in the main branch :)