YangLing0818 / RPG-DiffusionMaster

[ICML 2024] Mastering Text-to-Image Diffusion: Recaptioning, Planning, and Generating with Multimodal LLMs (RPG)
https://proceedings.mlr.press/v235/yang24ai.html
MIT License

RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false #31

Open adammenges opened 9 months ago

adammenges commented 9 months ago

I got the following error when trying to use the notebook as is (no modifications), in the 5th cell, the one running pipe(...):

RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Any ideas?

Full trace below:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 1
----> 1 images = pipe(prompt,negative_prompt,
      2               batch_size = 2, #batch size
      3               num_inference_steps=30, # sampling step
      4               height = 896, 
      5               width = 640, 
      6               end_steps = 1, # The number of steps to end the attention double version (specified in a ratio of 0-1. If it is 1, attention double version will be applied in all steps, with 0 being the normal generation)
      7               base_ratio=0.2, # Base ratio, the weight of base prompt, if 0, all are regional prompts, if 1, all are base prompts
      8               seed = 4396, # random seed
      9 )

Cell In[1], line 108, in RegionalGenerator.__call__(self, prompts, negative_prompt, batch_size, height, width, guidance_scale, num_inference_steps, seed, base_ratio, end_steps)
    106 #predict noise
    107 with torch.no_grad():
--> 108     noise_pred = self.unet(sample = latent_model_input,timestep = t,encoder_hidden_states=text_embs).sample
    110 #negative CFG
    111 noise_pred_text, noise_pred_negative= noise_pred.chunk(2)

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py:905, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
    903 for downsample_block in self.down_blocks:
    904     if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 905         sample, res_samples = downsample_block(
    906             hidden_states=sample,
    907             temb=emb,
    908             encoder_hidden_states=encoder_hidden_states,
    909             attention_mask=attention_mask,
    910             cross_attention_kwargs=cross_attention_kwargs,
    911             encoder_attention_mask=encoder_attention_mask,
    912         )
    913     else:
    914         sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:993, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask)
    991     else:
    992         hidden_states = resnet(hidden_states, temb)
--> 993         hidden_states = attn(
    994             hidden_states,
    995             encoder_hidden_states=encoder_hidden_states,
    996             cross_attention_kwargs=cross_attention_kwargs,
    997             attention_mask=attention_mask,
    998             encoder_attention_mask=encoder_attention_mask,
    999             return_dict=False,
   1000         )[0]
   1002     output_states = output_states + (hidden_states,)
   1004 if self.downsamplers is not None:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:291, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
    289 # 2. Blocks
    290 for block in self.transformer_blocks:
--> 291     hidden_states = block(
    292         hidden_states,
    293         attention_mask=attention_mask,
    294         encoder_hidden_states=encoder_hidden_states,
    295         encoder_attention_mask=encoder_attention_mask,
    296         timestep=timestep,
    297         cross_attention_kwargs=cross_attention_kwargs,
    298         class_labels=class_labels,
    299     )
    301 # 3. Output
    302 if self.is_input_continuous:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention.py:170, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
    165 if self.attn2 is not None:
    166     norm_hidden_states = (
    167         self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
    168     )
--> 170     attn_output = self.attn2(
    171         norm_hidden_states,
    172         encoder_hidden_states=encoder_hidden_states,
    173         attention_mask=encoder_attention_mask,
    174         **cross_attention_kwargs,
    175     )
    176     hidden_states = attn_output + hidden_states
    178 # 3. Feed-forward

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:321, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
    317 def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
    318     # The `Attention` class can call different attention processors / attention functions
    319     # here we simply pass along all tensors to the selected processor class
    320     # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 321     return self.processor(
    322         self,
    323         hidden_states,
    324         encoder_hidden_states=encoder_hidden_states,
    325         attention_mask=attention_mask,
    326         **cross_attention_kwargs,
    327     )

File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:1046, in XFormersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
   1043 key = attn.head_to_batch_dim(key).contiguous()
   1044 value = attn.head_to_batch_dim(value).contiguous()
-> 1046 hidden_states = xformers.ops.memory_efficient_attention(
   1047     query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
   1048 )
   1049 hidden_states = hidden_states.to(query.dtype)
   1050 hidden_states = attn.batch_to_head_dim(hidden_states)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:197, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op)
    117 def memory_efficient_attention(
    118     query: torch.Tensor,
    119     key: torch.Tensor,
   (...)
    125     op: Optional[AttentionOp] = None,
    126 ) -> torch.Tensor:
    127     """Implements the memory-efficient attention mechanism following
    128     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
    129 
   (...)
    195     :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
    196     """
--> 197     return _memory_efficient_attention(
    198         Inputs(
    199             query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
    200         ),
    201         op=op,
    202     )

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:293, in _memory_efficient_attention(inp, op)
    288 def _memory_efficient_attention(
    289     inp: Inputs, op: Optional[AttentionOp] = None
    290 ) -> torch.Tensor:
    291     # fast-path that doesn't require computing the logsumexp for backward computation
    292     if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
--> 293         return _memory_efficient_attention_forward(
    294             inp, op=op[0] if op is not None else None
    295         )
    297     output_shape = inp.normalize_bmhk()
    298     return _fMHA.apply(
    299         op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
    300     ).reshape(output_shape)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:313, in _memory_efficient_attention_forward(inp, op)
    310 else:
    311     _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
--> 313 out, *_ = op.apply(inp, needs_gradient=False)
    314 return out.reshape(output_shape)

File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/cutlass.py:106, in FwOp.apply(cls, inp, needs_gradient)
    104 causal = isinstance(inp.attn_bias, LowerTriangularMask)
    105 cu_seqlen_k, cu_seqlen_q, max_seqlen_q = _get_seqlen_info(inp)
--> 106 out, lse = cls.OPERATOR(
    107     query=inp.query,
    108     key=inp.key,
    109     value=inp.value,
    110     cu_seqlens_q=cu_seqlen_q,
    111     cu_seqlens_k=cu_seqlen_k,
    112     max_seqlen_q=max_seqlen_q,
    113     compute_logsumexp=needs_gradient,
    114     causal=causal,
    115     scale=inp.scale,
    116 )
    117 ctx: Optional[Context] = None
    118 if needs_gradient:

File ~/jupyter/.venv/lib/python3.10/site-packages/torch/_ops.py:442, in OpOverloadPacket.__call__(self, *args, **kwargs)
    437 def __call__(self, *args, **kwargs):
    438     # overloading __call__ to ensure torch.ops.foo.bar()
    439     # is still callable from JIT
    440     # We save the function ptr as the `op` attribute on
    441     # OpOverloadPacket to access it here.
--> 442     return self._op(*args, **kwargs or {})

RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
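
For context, here is a minimal sketch of the invariant xformers enforces at the failing call (the shapes below are hypothetical, not taken from the notebook): after attn.head_to_batch_dim, query and key must agree along dim 0 (batch * num_heads), so the check trips when the latent batch fed to the UNet and the batch of encoder_hidden_states (text_embs) differ.

import torch

# Hypothetical shapes illustrating the condition checked inside
# xformers.ops.memory_efficient_attention: after head_to_batch_dim,
# dim 0 is batch * num_heads and must match between query and key.
heads = 8
query = torch.randn(2 * heads, 4096, 40)  # from hidden_states (latents), CFG batch of 2
key = torch.randn(4 * heads, 77, 40)      # from encoder_hidden_states (text_embs), batch of 4

if query.size(0) != key.size(0):
    # This is exactly the condition behind
    # "Expected query.size(0) == key.size(0) to be true, but got false."
    print(f"batch mismatch: query {query.size(0)} vs key {key.size(0)}")

In a regional-prompt setup like this one, a mismatch of that kind usually means the number of stacked prompt embeddings does not line up with the (CFG-doubled) latent batch, so printing latent_model_input.shape and text_embs.shape right before the self.unet(...) call may be a useful first check.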