Got the following error when trying to use the Notebook as-is (no modifications), in the 5th cell, the one that runs pipe(...).
RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Any ideas?
Full trace below:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 images = pipe(prompt,negative_prompt,
2 batch_size = 2, #batch size
3 num_inference_steps=30, # sampling step
4 height = 896,
5 width = 640,
6 end_steps = 1, # The number of steps to end the attention double version (specified in a ratio of 0-1. If it is 1, attention double version will be applied in all steps, with 0 being the normal generation)
7 base_ratio=0.2, # Base ratio, the weight of base prompt, if 0, all are regional prompts, if 1, all are base prompts
8 seed = 4396, # random seed
9 )
Cell In[1], line 108, in RegionalGenerator.__call__(self, prompts, negative_prompt, batch_size, height, width, guidance_scale, num_inference_steps, seed, base_ratio, end_steps)
106 #predict noise
107 with torch.no_grad():
--> 108 noise_pred = self.unet(sample = latent_model_input,timestep = t,encoder_hidden_states=text_embs).sample
110 #negative CFG
111 noise_pred_text, noise_pred_negative= noise_pred.chunk(2)
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py:905, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
903 for downsample_block in self.down_blocks:
904 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 905 sample, res_samples = downsample_block(
906 hidden_states=sample,
907 temb=emb,
908 encoder_hidden_states=encoder_hidden_states,
909 attention_mask=attention_mask,
910 cross_attention_kwargs=cross_attention_kwargs,
911 encoder_attention_mask=encoder_attention_mask,
912 )
913 else:
914 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:993, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask)
991 else:
992 hidden_states = resnet(hidden_states, temb)
--> 993 hidden_states = attn(
994 hidden_states,
995 encoder_hidden_states=encoder_hidden_states,
996 cross_attention_kwargs=cross_attention_kwargs,
997 attention_mask=attention_mask,
998 encoder_attention_mask=encoder_attention_mask,
999 return_dict=False,
1000 )[0]
1002 output_states = output_states + (hidden_states,)
1004 if self.downsamplers is not None:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:291, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
289 # 2. Blocks
290 for block in self.transformer_blocks:
--> 291 hidden_states = block(
292 hidden_states,
293 attention_mask=attention_mask,
294 encoder_hidden_states=encoder_hidden_states,
295 encoder_attention_mask=encoder_attention_mask,
296 timestep=timestep,
297 cross_attention_kwargs=cross_attention_kwargs,
298 class_labels=class_labels,
299 )
301 # 3. Output
302 if self.is_input_continuous:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention.py:170, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
165 if self.attn2 is not None:
166 norm_hidden_states = (
167 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
168 )
--> 170 attn_output = self.attn2(
171 norm_hidden_states,
172 encoder_hidden_states=encoder_hidden_states,
173 attention_mask=encoder_attention_mask,
174 **cross_attention_kwargs,
175 )
176 hidden_states = attn_output + hidden_states
178 # 3. Feed-forward
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:321, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
317 def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
318 # The `Attention` class can call different attention processors / attention functions
319 # here we simply pass along all tensors to the selected processor class
320 # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 321 return self.processor(
322 self,
323 hidden_states,
324 encoder_hidden_states=encoder_hidden_states,
325 attention_mask=attention_mask,
326 **cross_attention_kwargs,
327 )
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:1046, in XFormersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
1043 key = attn.head_to_batch_dim(key).contiguous()
1044 value = attn.head_to_batch_dim(value).contiguous()
-> 1046 hidden_states = xformers.ops.memory_efficient_attention(
1047 query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1048 )
1049 hidden_states = hidden_states.to(query.dtype)
1050 hidden_states = attn.batch_to_head_dim(hidden_states)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:197, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op)
117 def memory_efficient_attention(
118 query: torch.Tensor,
119 key: torch.Tensor,
(...)
125 op: Optional[AttentionOp] = None,
126 ) -> torch.Tensor:
127 """Implements the memory-efficient attention mechanism following
128 `"Self-Attention Does Not Need O(n^2) Memory" <[http://arxiv.org/abs/2112.05682>`_](http://arxiv.org/abs/2112.05682%3E%60_).
129
(...)
195 :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
196 """
--> 197 return _memory_efficient_attention(
198 Inputs(
199 query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
200 ),
201 op=op,
202 )
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:293, in _memory_efficient_attention(inp, op)
288 def _memory_efficient_attention(
289 inp: Inputs, op: Optional[AttentionOp] = None
290 ) -> torch.Tensor:
291 # fast-path that doesn't require computing the logsumexp for backward computation
292 if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
--> 293 return _memory_efficient_attention_forward(
294 inp, op=op[0] if op is not None else None
295 )
297 output_shape = inp.normalize_bmhk()
298 return _fMHA.apply(
299 op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
300 ).reshape(output_shape)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:313, in _memory_efficient_attention_forward(inp, op)
310 else:
311 _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
--> 313 out, *_ = op.apply(inp, needs_gradient=False)
314 return out.reshape(output_shape)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/cutlass.py:106, in FwOp.apply(cls, inp, needs_gradient)
104 causal = isinstance(inp.attn_bias, LowerTriangularMask)
105 cu_seqlen_k, cu_seqlen_q, max_seqlen_q = _get_seqlen_info(inp)
--> 106 out, lse = cls.OPERATOR(
107 query=inp.query,
108 key=inp.key,
109 value=inp.value,
110 cu_seqlens_q=cu_seqlen_q,
111 cu_seqlens_k=cu_seqlen_k,
112 max_seqlen_q=max_seqlen_q,
113 compute_logsumexp=needs_gradient,
114 causal=causal,
115 scale=inp.scale,
116 )
117 ctx: Optional[Context] = None
118 if needs_gradient:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/_ops.py:442, in OpOverloadPacket.__call__(self, *args, **kwargs)
437 def __call__(self, *args, **kwargs):
438 # overloading __call__ to ensure torch.ops.foo.bar()
439 # is still callable from JIT
440 # We save the function ptr as the `op` attribute on
441 # OpOverloadPacket to access it here.
--> 442 return self._op(*args, **kwargs or {})
RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
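For context on what the failing check means (not a fix, just my reading of the trace): diffusers' XFormersAttnProcessor folds the attention heads into the batch dimension via head_to_batch_dim before calling xformers, and xformers then requires query and key to agree on dim 0. Below is a rough sketch with made-up shapes that trips the same check; it assumes a CUDA device with this older xformers version installed (newer xformers may catch the mismatch earlier with a different message).

```python
# Rough sketch with made-up shapes (assumes CUDA + xformers are available).
# In diffusers' XFormersAttnProcessor, head_to_batch_dim folds heads into dim 0,
# so query/key arrive as [batch * heads, seq_len, head_dim].
import torch
import xformers.ops

query = torch.randn(16, 4480, 40, device="cuda", dtype=torch.float16)  # from the latents
key   = torch.randn(48,   77, 40, device="cuda", dtype=torch.float16)  # from the text embeddings
value = torch.randn(48,   77, 40, device="cuda", dtype=torch.float16)

# dim 0 of query (16) != dim 0 of key (48) -> this is the check that fails in the trace
xformers.ops.memory_efficient_attention(query, key, value)
```

So my guess is that the latent batch and the text-embedding batch end up with different effective sizes somewhere in the regional-prompt path (maybe related to batch_size = 2 and the number of regional prompts?), but I can't tell from the trace alone where that happens.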