IrisRainbowNeko / HCP-Diffusion

RuntimeError: shape '[616, 1, 40]' is invalid for input of size 49280 #18

Open biasnhbi opened 1 year ago

biasnhbi commented 1 year ago

```
╭─────────────────── Traceback (most recent call last) ────────────────────╮
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/runpy.py:196 in           │
│ _run_module_as_main                                                      │
│                                                                          │
│   193 │   main_globals = sys.modules["__main__"].__dict__                │
│   194 │   if alter_argv:                                                 │
│   195 │   │   sys.argv[0] = mod_spec.origin                              │
│ ❱ 196 │   return _run_code(code, main_globals, None,                     │
│   197 │   │   │   │   │    "__main__", mod_spec)                         │
│   198                                                                    │
│   199 def run_module(mod_name, init_globals=None,                        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/runpy.py:86 in _run_code  │
│                                                                          │
│    83 │   │   │   │   │      __loader__ = loader,                        │
│    84 │   │   │   │   │      __package__ = pkg_name,                     │
│    85 │   │   │   │   │      __spec__ = mod_spec)                        │
│ ❱  86 │   exec(code, run_globals)                                        │
│    87 │   return run_globals                                             │
│    88                                                                    │
│    89 def _run_module_code(code, init_globals=None,                      │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac_single.py:105 in <module>                                          │
│                                                                          │
│   102 │                                                                  │
│   103 │   conf = load_config_with_cli(args.cfg, args_list=sys.argv[3:])  │
│   104 │   trainer = TrainerSingleCard(conf)                              │
│ ❱ 105 │   trainer.train()                                                │
│   106                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac.py:409 in train                                                    │
│                                                                          │
│   406 │   │                                                              │
│   407 │   │   loss_sum = np.ones(30)                                     │
│   408 │   │   for data_list in self.train_loader_group:                  │
│ ❱ 409 │   │   │   loss = self.train_one_step(data_list)                  │
│   410 │   │   │   loss_sum[self.global_step%len(loss_sum)] = loss        │
│   411 │   │   │                                                          │
│   412 │   │   │   self.global_step += 1                                  │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac.py:501 in train_one_step                                           │
│                                                                          │
│   498 │   │   │   │   other_datas = {k:v.to(self.device, dtype=self.weig │
│   499 │   │   │   │                                                      │
│   500 │   │   │   │   latents = self.get_latents(image, self.train_loade │
│ ❱ 501 │   │   │   │   model_pred, target, timesteps = self.forward(laten │
│   502 │   │   │   │   loss = self.get_loss(model_pred, target, timesteps │
│   503 │   │   │   │   self.accelerator.backward(loss)                    │
│   504                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac.py:479 in forward                                                  │
│                                                                          │
│   476 │   │                                                              │
│   477 │   │   # CFG context for DreamArtist                              │
│   478 │   │   noisy_latents, timesteps = self.cfg_context.pre(noisy_late │
│ ❱ 479 │   │   model_pred = self.encode_decode(prompt_ids, noisy_latents, │
│   480 │   │   model_pred = self.cfg_context.post(model_pred)             │
│   481 │   │                                                              │
│   482 │   │   # Get the target for loss depending on the prediction type │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac_single.py:78 in encode_decode                                      │
│                                                                          │
│    75 │   │   │   │   feeder(input_all)                                  │
│    76 │   │                                                              │
│    77 │   │   encoder_hidden_states = self.text_encoder(prompt_ids, outp │
│ ❱  78 │   │   model_pred = self.unet(noisy_latents, timesteps, encoder_h │
│    79 │   │   return model_pred                                          │
│    80 │                                                                  │
│    81 │   def get_loss(self, model_pred, target, timesteps, att_mask):   │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/mo │
│ dules/module.py:1130 in _call_impl                                       │
│                                                                          │
│   1127 │   │   # this function, and just call forward.                   │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or se │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_h │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                 │
│   1131 │   │   # Do not call functions when jit is used                  │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []     │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/accelerate/ │
│ utils/operations.py:553 in forward                                       │
│                                                                          │
│   550 │   model_forward = ConvertOutputsToFp32(model_forward)            │
│   551 │                                                                  │
│   552 │   def forward(*args, **kwargs):                                  │
│ ❱ 553 │   │   return model_forward(*args, **kwargs)                      │
│   554 │                                                                  │
│   555 │   # To act like a decorator so that it can be popped when doing  │
│   556 │   forward.__wrapped__ = model_forward                            │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/accelerate/ │
│ utils/operations.py:541 in __call__                                      │
│                                                                          │
│   538 │   │   update_wrapper(self, model_forward)                        │
│   539 │                                                                  │
│   540 │   def __call__(self, *args, **kwargs):                           │
│ ❱ 541 │   │   return convert_to_fp32(self.model_forward(*args, **kwargs) │
│   542 │                                                                  │
│   543 │   def __getstate__(self):                                        │
│   544 │   │   raise pickle.PicklingError(                                │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/amp/a │
│ utocast_mode.py:12 in decorate_autocast                                  │
│                                                                          │
│     9 │   @functools.wraps(func)                                         │
│    10 │   def decorate_autocast(*args, **kwargs):                        │
│    11 │   │   with autocast_instance:                                    │
│ ❱  12 │   │   │   return func(*args, **kwargs)                           │
│    13 │   decorate_autocast.__script_unsupported = '@autocast() decorato │
│    14 │   return decorate_autocast                                       │
│    15                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/unet_2d_condition.py:481 in forward                                │
│                                                                          │
│   478 │   │   down_block_res_samples = (sample,)                         │
│   479 │   │   for downsample_block in self.down_blocks:                  │
│   480 │   │   │   if hasattr(downsample_block, "has_cross_attention") an │
│ ❱ 481 │   │   │   │   sample, res_samples = downsample_block(            │
│   482 │   │   │   │   │   hidden_states=sample,                          │
│   483 │   │   │   │   │   temb=emb,                                      │
│   484 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,   │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/mo │
│ dules/module.py:1130 in _call_impl                                       │
│                                                                          │
│   1127 │   │   # this function, and just call forward.                   │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or se │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_h │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                 │
│   1131 │   │   # Do not call functions when jit is used                  │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []     │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/unet_2d_blocks.py:781 in forward                                   │
│                                                                          │
│    778 │   │   │   │   │   return custom_forward                         │
│    779 │   │   │   │                                                     │
│    780 │   │   │   │   hidden_states = torch.utils.checkpoint.checkpoint │
│ ❱  781 │   │   │   │   hidden_states = torch.utils.checkpoint.checkpoint │
│    782 │   │   │   │   │   create_custom_forward(attn, return_dict=False │
│    783 │   │   │   │   │   hidden_states,                                │
│    784 │   │   │   │   │   encoder_hidden_states,                        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/hcpdiff/tra │
│ in_ac.py:48 in checkpoint_fix                                            │
│                                                                          │
│    45 # fix checkpoint bug for train part of model                       │
│    46 import torch.utils.checkpoint                                      │
│    47 def checkpoint_fix(function, *args, use_reentrant: bool = False, c │
│ ❱  48 │   return checkpoint_raw(function, *args, use_reentrant=use_reent │
│    49 torch.utils.checkpoint.checkpoint = checkpoint_fix                 │
│    50                                                                    │
│    51 class Trainer:                                                     │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/utils │
│ /checkpoint.py:237 in checkpoint                                         │
│                                                                          │
│   234 │   if use_reentrant:                                              │
│   235 │   │   return CheckpointFunction.apply(function, preserve, *args) │
│   236 │   else:                                                          │
│ ❱ 237 │   │   return _checkpoint_without_reentrant(                      │
│   238 │   │   │   function,                                              │
│   239 │   │   │   preserve,                                              │
│   240 │   │   │   *args                                                  │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/utils │
│ /checkpoint.py:383 in _checkpoint_without_reentrant                      │
│                                                                          │
│   380 │   │   return storage.pop(x)                                      │
│   381 │                                                                  │
│   382 │   with torch.autograd.graph.saved_tensors_hooks(pack, unpack):   │
│ ❱ 383 │   │   output = function(*args)                                   │
│   384 │   │   if torch.cuda._initialized and preserve_rng_state and not  │
│   385 │   │   │   # Cuda was not initialized before running the forward, │
│   386 │   │   │   # stash the CUDA state.                                │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/unet_2d_blocks.py:774 in custom_forward                            │
│                                                                          │
│    771 │   │   │   │   def create_custom_forward(module, return_dict=Non │
│    772 │   │   │   │   │   def custom_forward(*inputs):                  │
│    773 │   │   │   │   │   │   if return_dict is not None:               │
│ ❱  774 │   │   │   │   │   │   │   return module(*inputs, return_dict=re │
│    775 │   │   │   │   │   │   else:                                     │
│    776 │   │   │   │   │   │   │   return module(*inputs)                │
│    777                                                                   │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/mo │
│ dules/module.py:1130 in _call_impl                                       │
│                                                                          │
│   1127 │   │   # this function, and just call forward.                   │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or se │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_h │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                 │
│   1131 │   │   # Do not call functions when jit is used                  │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []     │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/transformer_2d.py:265 in forward                                   │
│                                                                          │
│   262 │   │                                                              │
│   263 │   │   # 2. Blocks                                                │
│   264 │   │   for block in self.transformer_blocks:                      │
│ ❱ 265 │   │   │   hidden_states = block(                                 │
│   266 │   │   │   │   hidden_states,                                     │
│   267 │   │   │   │   encoder_hidden_states=encoder_hidden_states,       │
│   268 │   │   │   │   timestep=timestep,                                 │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/mo │
│ dules/module.py:1130 in _call_impl                                       │
│                                                                          │
│   1127 │   │   # this function, and just call forward.                   │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or se │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_h │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                 │
│   1131 │   │   # Do not call functions when jit is used                  │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []     │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/attention.py:307 in forward                                        │
│                                                                          │
│   304 │   │   │   )                                                      │
│   305 │   │   │                                                          │
│   306 │   │   │   # 2. Cross-Attention                                   │
│ ❱ 307 │   │   │   attn_output = self.attn2(                              │
│   308 │   │   │   │   norm_hidden_states,                                │
│   309 │   │   │   │   encoder_hidden_states=encoder_hidden_states,       │
│   310 │   │   │   │   attention_mask=attention_mask,                     │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/mo │
│ dules/module.py:1130 in _call_impl                                       │
│                                                                          │
│   1127 │   │   # this function, and just call forward.                   │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or se │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_h │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                 │
│   1131 │   │   # Do not call functions when jit is used                  │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []     │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:        │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/cross_attention.py:160 in forward                                  │
│                                                                          │
│   157 │   │   # The `CrossAttention` class can call different attention  │
│   158 │   │   # here we simply pass along all tensors to the selected pr │
│   159 │   │   # For standard processors that are defined here, `**cross_ │
│ ❱ 160 │   │   return self.processor(                                     │
│   161 │   │   │   self,                                                  │
│   162 │   │   │   hidden_states,                                         │
│   163 │   │   │   encoder_hidden_states=encoder_hidden_states,           │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/diffusers/m │
│ odels/cross_attention.py:374 in __call__                                 │
│                                                                          │
│   371 │   │   key = attn.head_to_batch_dim(key).contiguous()             │
│   372 │   │   value = attn.head_to_batch_dim(value).contiguous()         │
│   373 │   │                                                              │
│ ❱ 374 │   │   hidden_states = xformers.ops.memory_efficient_attention(   │
│   375 │   │   │   query, key, value, attn_bias=attention_mask, op=self.a │
│   376 │   │   )                                                          │
│   377 │   │   hidden_states = hidden_states.to(query.dtype)              │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/__init__.py:192 in memory_efficient_attention                     │
│                                                                          │
│   189 │   │   and options.                                               │
│   190 │   :return: multi-head attention Tensor with shape ``[B, Mq, H, K │
│   191 │   """                                                            │
│ ❱ 192 │   return _memory_efficient_attention(                            │
│   193 │   │   Inputs(                                                    │
│   194 │   │   │   query=query, key=key, value=value, p=p, attn_bias=attn │
│   195 │   │   ),                                                         │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/__init__.py:295 in _memory_efficient_attention                    │
│                                                                          │
│   292 │   │   )                                                          │
│   293 │                                                                  │
│   294 │   output_shape = inp.normalize_bmhk()                            │
│ ❱ 295 │   return _fMHA.apply(                                            │
│   296 │   │   op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, i │
│   297 │   ).reshape(output_shape)                                        │
│   298                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/__init__.py:41 in forward                                         │
│                                                                          │
│    38 │   │   op_fw = op[0] if op is not None else None                  │
│    39 │   │   op_bw = op[1] if op is not None else None                  │
│    40 │   │                                                              │
│ ❱  41 │   │   out, op_ctx = _memory_efficient_attention_forward_requires │
│    42 │   │   │   inp=inp, op=op_fw                                      │
│    43 │   │   )                                                          │
│    44                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/__init__.py:323 in                                                │
│ _memory_efficient_attention_forward_requires_grad                        │
│                                                                          │
│   320 │   │   op = _dispatch_fw(inp)                                     │
│   321 │   else:                                                          │
│   322 │   │   _ensure_op_supports_or_raise(ValueError, "memory_efficient │
│ ❱ 323 │   out = op.apply(inp, needs_gradient=True)                       │
│   324 │   assert out[1] is not None                                      │
│   325 │   return (out[0].reshape(output_shape), out[1])                  │
│   326                                                                    │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/flash.py:235 in apply                                             │
│                                                                          │
│   232 │   │   │   max_seqlen_q,                                          │
│   233 │   │   │   cu_seqlens_k,                                          │
│   234 │   │   │   max_seqlen_k,                                          │
│ ❱ 235 │   │   ) = _convert_input_format(inp)                             │
│   236 │   │   out, softmax_lse, rng_state = cls.OPERATOR(                │
│   237 │   │   │   inp.query,                                             │
│   238 │   │   │   inp.key,                                               │
│                                                                          │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/xformers/op │
│ s/fmha/flash.py:177 in _convert_input_format                             │
│                                                                          │
│   174 │   new_inp = replace(                                             │
│   175 │   │   inp,                                                       │
│   176 │   │   query=query.reshape([batch * seqlen_q, num_heads, head_dim │
│ ❱ 177 │   │   key=key.reshape([batch * seqlen_kv, num_heads, head_dim_q] │
│   178 │   │   value=value.reshape([batch * seqlen_kv, num_heads, head_di │
│   179 │   )                                                              │
│   180 │   softmax_scale = inp.query.shape[-1] ** (-0.5) if inp.scale is  │
╰──────────────────────────────────────────────────────────────────────────╯
RuntimeError: shape '[616, 1, 40]' is invalid for input of size 49280       
```
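
For context on the numbers: 616 × 1 × 40 = 24640, while the tensor being reshaped holds 49280 = 2 × 24640 elements, so the key/value tensor reaching xformers carries exactly twice as many elements as the query-derived target shape allows. The snippet below only illustrates that arithmetic; it is not code from HCP-Diffusion or xformers, and the tensor sizes are hypothetical, chosen to reproduce the same error message:

```python
import torch

# Illustration only: reproduce the reshape that fails inside xformers'
# flash-attention input conversion. All sizes here are hypothetical,
# picked to match the numbers in the error message.
batch_times_seqlen = 616   # target leading dimension derived from the query
num_heads = 1
head_dim = 40

# A key tensor with twice as many elements as the target shape can hold,
# e.g. because its batch dimension is 2x the query's (1232 * 40 = 49280).
key = torch.randn(1232, 40)
print(key.numel())  # 49280

try:
    key.reshape([batch_times_seqlen, num_heads, head_dim])
except RuntimeError as e:
    # RuntimeError: shape '[616, 1, 40]' is invalid for input of size 49280
    print(e)
```

If this reading is correct, the factor of two points to a batch-size disagreement between `noisy_latents` and `encoder_hidden_states` at the `self.unet(...)` call in `encode_decode` (for example, one of them doubled by the DreamArtist CFG context and the other not). Printing both shapes just before that call should show which side is off.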