huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Shape Error with Transformer2DModel and "adanorm" #7575

Open · will-rice opened 5 months ago

will-rice commented 5 months ago

Describe the bug

I'm hitting a shape error where scale and shift end up a different size from x. I'm not positive this is a bug, so it could just be me doing something wrong. Thanks in advance for taking a look at this.

Reproduction

import torch
from diffusers import Transformer2DModel

# Patch-based transformer conditioned on timesteps via norm_type="ada_norm".
model = Transformer2DModel(
    in_channels=4,
    out_channels=4,
    patch_size=2,
    sample_size=32,
    norm_type="ada_norm",
    num_embeds_ada_norm=1000,
    cross_attention_dim=1024,
)
# One diffusion timestep per batch element (batch size 8).
timesteps = torch.randint(0, 1000, (8,), dtype=torch.long)

model(
    torch.randn(8, 4, 32, 32),                       # (batch, channels, height, width)
    timestep=timesteps,
    encoder_hidden_states=torch.randn(8, 18, 1024),  # (batch, seq_len, cross_attention_dim)
)
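
For reference, with this config the hidden states entering each transformer block should be (batch, tokens, inner_dim) = (8, 256, 1408), assuming the default num_attention_heads=16 and attention_head_dim=88 (I haven't overridden them above). A quick sanity check of those numbers:

# Sanity check of the expected hidden-state shape; 16 heads x 88 head_dim are
# the Transformer2DModel defaults as far as I can tell.
num_attention_heads, attention_head_dim = 16, 88
batch, sample_size, patch_size = 8, 32, 2

inner_dim = num_attention_heads * attention_head_dim  # 1408, the "tensor a" size in the error below
tokens = (sample_size // patch_size) ** 2             # 256 patch tokens per image
print(batch, tokens, inner_dim)                       # 8 256 1408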

Logs

RuntimeError                              Traceback (most recent call last)
Cell In[10], line 15
      4 model = Transformer2DModel(
      5     in_channels=4,
      6     out_channels=4,
   (...)
     11     cross_attention_dim=1024,
     12 )
     13 timesteps = torch.randint(0, 1000, (8,), dtype=torch.long)
---> 15 model(
     16     torch.randn(8, 4, 32, 32),
     17     timestep=timesteps,
     18     encoder_hidden_states=torch.randn(8, 18, 1024),
     19 )

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:397, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    385         hidden_states = torch.utils.checkpoint.checkpoint(
    386             create_custom_forward(block),
    387             hidden_states,
   (...)
    394             **ckpt_kwargs,
    395         )
    396     else:
--> 397         hidden_states = block(
    398             hidden_states,
    399             attention_mask=attention_mask,
    400             encoder_hidden_states=encoder_hidden_states,
    401             encoder_attention_mask=encoder_attention_mask,
    402             timestep=timestep,
    403             cross_attention_kwargs=cross_attention_kwargs,
    404             class_labels=class_labels,
    405         )
    407 # 3. Output
    408 if self.is_input_continuous:

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/diffusers/models/attention.py:303, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, added_cond_kwargs)
    300 batch_size = hidden_states.shape[0]
    302 if self.norm_type == "ada_norm":
--> 303     norm_hidden_states = self.norm1(hidden_states, timestep)
    304 elif self.norm_type == "ada_norm_zero":
    305     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
    306         hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
    307     )

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/.pyenv/versions/notebook/lib/python3.10/site-packages/diffusers/models/normalization.py:47, in AdaLayerNorm.forward(self, x, timestep)
     45 emb = self.linear(self.silu(self.emb(timestep)))
     46 scale, shift = torch.chunk(emb, 2)
---> 47 x = self.norm(x) * (1 + scale) + shift
     48 return x

RuntimeError: The size of tensor a (1408) must match the size of tensor b (2816) at non-singleton dimension 2
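
If I'm reading diffusers/models/normalization.py right, the mismatch comes from torch.chunk(emb, 2) using the default dim=0: it splits the batch axis of the (8, 2816) embedding instead of its feature axis. Below is a minimal sketch of the shape arithmetic with the sizes from the traceback; chunking along the last dim and broadcasting over the sequence axis is one way the shapes would line up, though I don't know if that's the intended fix:

import torch

batch, tokens, inner_dim = 8, 256, 1408     # sizes inferred from the traceback above
x = torch.randn(batch, tokens, inner_dim)   # stands in for self.norm(x) in AdaLayerNorm.forward
emb = torch.randn(batch, inner_dim * 2)     # stands in for self.linear(self.silu(self.emb(timestep)))

# What the current code does: torch.chunk defaults to dim=0, so the batch axis is split.
scale, shift = torch.chunk(emb, 2)
print(scale.shape)                          # torch.Size([4, 2816])
# x * (1 + scale) then broadcasts (8, 256, 1408) against (4, 2816):
# 1408 vs 2816 at non-singleton dimension 2, exactly the error above.

# Chunking the feature axis and adding a sequence axis makes the shapes line up.
scale, shift = torch.chunk(emb, 2, dim=-1)  # each (8, 1408)
out = x * (1 + scale[:, None, :]) + shift[:, None, :]
print(out.shape)                            # torch.Size([8, 256, 1408])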

System Info

Who can help?

@DN6 @yiyixuxu @sayakpaul

DN6 commented 5 months ago

@sayakpaul could you take a look here, please?

sayakpaul commented 5 months ago

@DN6 https://github.com/huggingface/diffusers/pull/7578

github-actions[bot] commented 4 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.