huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

AttributeError: 'tuple' object has no attribute 'shape' #8863

Open innat-asj opened 2 months ago

innat-asj commented 2 months ago

Describe the bug

I've built the following pipeline with inpainting and an IP-Adapter. To reduce the memory footprint, I also enabled model CPU offload and xFormers memory-efficient attention, which causes the reported error.

Reproduction

import torch

from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image

pipe = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    safety_checker=None,
    torch_dtype=torch.float16
)
pipe.load_ip_adapter(
    "h94/IP-Adapter", 
    subfolder="models", 
    weight_name="ip-adapter-plus_sd15.bin"
)
pipe.set_ip_adapter_scale(0.7)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()

mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_mask.png")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png")
ip_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png")

generator = torch.Generator(device="cuda").manual_seed(4)
images = pipe(
    prompt="a cute gummy bear waving",
    image=image,
    mask_image=mask_image,
    ip_adapter_image=ip_image,
    generator=generator,
    num_inference_steps=10,
).images
images[0]

Logs


```shell
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[26], line 2
      1 generator = torch.Generator(device="cuda").manual_seed(4)
----> 2 images = pipe(
      3     prompt="a cute gummy bear waving",
      4     image=image,
      5     mask_image=mask_image,
      6     ip_adapter_image=ip_image,
      7     generator=generator,
      8     num_inference_steps=10,
      9 ).images
     10 images[0]

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py:1384, in StableDiffusionInpaintPipeline.__call__(self, prompt, image, mask_image, masked_image_latents, height, width, padding_mask_crop, strength, num_inference_steps, timesteps, sigmas, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, output_type, return_dict, cross_attention_kwargs, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
   1381     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
   1383 # predict the noise residual
-> 1384 noise_pred = self.unet(
   1385     latent_model_input,
   1386     t,
   1387     encoder_hidden_states=prompt_embeds,
   1388     timestep_cond=timestep_cond,
   1389     cross_attention_kwargs=self.cross_attention_kwargs,
   1390     added_cond_kwargs=added_cond_kwargs,
   1391     return_dict=False,
   1392 )[0]
   1394 # perform guidance
   1395 if self.do_classifier_free_guidance:

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1209, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1206     if is_adapter and len(down_intrablock_additional_residuals) > 0:
   1207         additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
-> 1209     sample, res_samples = downsample_block(
   1210         hidden_states=sample,
   1211         temb=emb,
   1212         encoder_hidden_states=encoder_hidden_states,
   1213         attention_mask=attention_mask,
   1214         cross_attention_kwargs=cross_attention_kwargs,
   1215         encoder_attention_mask=encoder_attention_mask,
   1216         **additional_residuals,
   1217     )
   1218 else:
   1219     sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py:1288, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask, additional_residuals)
   1286 else:
   1287     hidden_states = resnet(hidden_states, temb)
-> 1288     hidden_states = attn(
   1289         hidden_states,
   1290         encoder_hidden_states=encoder_hidden_states,
   1291         cross_attention_kwargs=cross_attention_kwargs,
   1292         attention_mask=attention_mask,
   1293         encoder_attention_mask=encoder_attention_mask,
   1294         return_dict=False,
   1295     )[0]
   1297 # apply additional residuals to the output of the last pair of resnet and attention blocks
   1298 if i == len(blocks) - 1 and additional_residuals is not None:

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:440, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    428         hidden_states = torch.utils.checkpoint.checkpoint(
    429             create_custom_forward(block),
    430             hidden_states,
   (...)
    437             **ckpt_kwargs,
    438         )
    439     else:
--> 440         hidden_states = block(
    441             hidden_states,
    442             attention_mask=attention_mask,
    443             encoder_hidden_states=encoder_hidden_states,
    444             encoder_attention_mask=encoder_attention_mask,
    445             timestep=timestep,
    446             cross_attention_kwargs=cross_attention_kwargs,
    447             class_labels=class_labels,
    448         )
    450 # 3. Output
    451 if self.is_input_continuous:

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/attention.py:490, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    487     if self.pos_embed is not None and self.norm_type != "ada_norm_single":
    488         norm_hidden_states = self.pos_embed(norm_hidden_states)
--> 490     attn_output = self.attn2(
    491         norm_hidden_states,
    492         encoder_hidden_states=encoder_hidden_states,
    493         attention_mask=encoder_attention_mask,
    494         **cross_attention_kwargs,
    495     )
    496     hidden_states = attn_output + hidden_states
    498 # 4. Feed-forward
    499 # i2vgen doesn't have this norm 🤷‍♂️

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/attention_processor.py:559, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    554     logger.warning(
    555         f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
    556     )
    557 cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
--> 559 return self.processor(
    560     self,
    561     hidden_states,
    562     encoder_hidden_states=encoder_hidden_states,
    563     attention_mask=attention_mask,
    564     **cross_attention_kwargs,
    565 )

File /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/models/attention_processor.py:1355, in XFormersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb, *args, **kwargs)
   1351     batch_size, channel, height, width = hidden_states.shape
   1352     hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
   1354 batch_size, key_tokens, _ = (
-> 1355     hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
   1356 )
   1358 attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
   1359 if attention_mask is not None:
   1360     # expand our mask's singleton query_tokens dimension:
   1361     #   [batch*heads,            1, key_tokens] ->
   (...)
   1364     #   [batch*heads, query_tokens, key_tokens]
   1365     # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.

AttributeError: 'tuple' object has no attribute 'shape'
```

System Info

Who can help?

@sayakpaul @yiyixuxu @DN6

Mainly it's caused by

pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()

tolgacangoz commented 2 months ago

PyTorch 2.2 ships with FlashAttention-2 in its native scaled dot-product attention, which is better for both memory and speed than the attention in PyTorch 1.x. If you upgrade PyTorch, you may no longer need xFormers at all. Is it necessary for you to stay on PyTorch 1.x?
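
For reference, a minimal sketch of that alternative (assuming PyTorch >= 2.0 and a recent diffusers release, where the pipeline should default to the native scaled-dot-product-attention processors, so the xFormers call can simply be dropped):

```python
import torch

from diffusers import AutoPipelineForInpainting

# On PyTorch >= 2.0, diffusers should select processors based on
# torch.nn.functional.scaled_dot_product_attention by default, so
# enable_xformers_memory_efficient_attention() is not needed at all.
pipe = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter-plus_sd15.bin",
)
pipe.set_ip_adapter_scale(0.7)
pipe.enable_model_cpu_offload()
# no pipe.enable_xformers_memory_efficient_attention() call
```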

innat-asj commented 2 months ago

I've already mentioned the packages and the versions that I need.

tolgacangoz commented 2 months ago

Could you try this for now?


 pipe = AutoPipelineForInpainting.from_pretrained(
     "runwayml/stable-diffusion-inpainting",
     safety_checker=None,
     torch_dtype=torch.float16,
 )
+pipe.enable_xformers_memory_efficient_attention()
 pipe.load_ip_adapter(
     "h94/IP-Adapter", 
     subfolder="models", 
     weight_name="ip-adapter-plus_sd15.bin"
 )
 pipe.set_ip_adapter_scale(0.7)
 pipe.enable_model_cpu_offload()
-pipe.enable_xformers_memory_efficient_attention()

DN6 commented 1 month ago

Hi @innat-asj, could you try what @tolgacangoz is suggesting and move enabling xFormers to before loading the IP-Adapter? The issue seems to come from how the attention processors are configured.
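
If it helps to confirm, here is a hypothetical diagnostic (assuming the pipe object from the reproduction above): count the attention processor classes on the UNet. When xFormers is enabled after the IP-Adapter, the IP-Adapter processors appear to be replaced by the plain XFormersAttnProcessor, which then receives the (text, image) embeds tuple and raises the reported AttributeError.

```python
from collections import Counter

# Hypothetical check, not from the thread: list which attention processor classes
# are currently active on the UNet. After load_ip_adapter() the cross-attention
# layers should use the IP-Adapter processors; if xFormers is enabled afterwards,
# XFormersAttnProcessor shows up here instead.
print(Counter(type(p).__name__ for p in pipe.unet.attn_processors.values()))
```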

innat-asj commented 1 month ago

Moving the xFormers call before loading the IP-Adapter works. Thanks. The error message could be improved, though.
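
For completeness, the reproduction from above with the two calls reordered as suggested (this is just the original snippet rearranged, not an independently verified script):

```python
import torch

from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image

pipe = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    safety_checker=None,
    torch_dtype=torch.float16,
)
# Enable xFormers first, then load the IP-Adapter, so the IP-Adapter
# attention processors are not overwritten afterwards.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter-plus_sd15.bin",
)
pipe.set_ip_adapter_scale(0.7)
pipe.enable_model_cpu_offload()

mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_mask.png")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png")
ip_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png")

generator = torch.Generator(device="cuda").manual_seed(4)
images = pipe(
    prompt="a cute gummy bear waving",
    image=image,
    mask_image=mask_image,
    ip_adapter_image=ip_image,
    generator=generator,
    num_inference_steps=10,
).images
images[0]
```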

sayakpaul commented 1 month ago

Thanks for the suggestion. Do you maybe want to open a PR?