IrisRainbowNeko / HCP-Diffusion

A universal Stable-Diffusion toolbox
Apache License 2.0
896 stars 75 forks

DreamArtist++ training error #16

Open yoke233 opened 1 year ago

yoke233 commented 1 year ago

cfgs/train/examples/DreamArtist++.yaml

Delete the dataset_class entry at the end of the data section.

Run: accelerate launch -m hcpdiff.train_ac_single --cfg cfgs/train/examples/DreamArtist++.yaml

Error: Expected query.size(0) == key.size(0) to be true, but got false.
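
A note on reading the error: the check that fails compares the leading (batch) dimension of the attention query and key. The sketch below reproduces that kind of mismatch with plain NumPy arrays. It assumes, without confirmation from the traceback, that the latents are duplicated for DreamArtist's positive and negative branches while the prompt embeddings are not. `check_attention_batch` and all shapes here are hypothetical and are not part of hcpdiff, diffusers, or xformers.

```python
import numpy as np


def check_attention_batch(query, key):
    """Return None when the leading (batch) dimensions agree, else a message.

    Hypothetical helper for illustration only; it mirrors the condition named
    in the error text (query.size(0) == key.size(0)).
    """
    if query.shape[0] != key.shape[0]:
        return f"batch mismatch: query {query.shape[0]} vs key {key.shape[0]}"
    return None


batch, heads = 4, 8  # batch size 4 matches the log; 8 heads is an assumption

# Assumption: the latent branch is doubled (positive + negative), so the
# query carries 2 * batch * heads rows after heads are folded into the batch.
query = np.zeros((2 * batch * heads, 16, 8))

# Assumption: prompt embeddings are produced once per image, so the key
# carries only batch * heads rows.
key = np.zeros((batch * heads, 16, 8))

message = check_attention_batch(query, key)
print(message)
```

If this reading is right, printing the shapes of `noisy_latents` and `encoder_hidden_states` just before the `self.unet(...)` call in `encode_decode` should show whether their batch sizes differ.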

(hcpd) ➜  HCP-Diffusion git:(main) ✗ accelerate launch -m hcpdiff.train_ac_single --cfg cfgs/train/examples/DreamArtist++2.yaml
[11:10:59] WARNING  The following values were not passed to `accelerate launch` and had defaults used instead:              launch.py:890
                            `--num_processes` was set to a value of `1`
                            `--num_machines` was set to a value of `1`
                            `--mixed_precision` was set to a value of `'no'`
                            `--dynamo_backend` was set to a value of `'no'`
                    To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
tensorboard is not available
wandb is not available
2023-05-15 11:11:01.927 | INFO     | hcpdiff.loggers.cli_logger:_info:30 - world_size: 1
2023-05-15 11:11:01.927 | INFO     | hcpdiff.loggers.cli_logger:_info:30 - accumulation: 1
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
2023-05-15 11:11:08.818 | INFO     | hcpdiff.data.bucket:build_buckets_from_images:130 - build buckets from images
/data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:870: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning
  warnings.warn(
2023-05-15 11:11:08.841 | INFO     | hcpdiff.data.bucket:build_buckets_from_images:159 - buckets info: size:[512 512], num:1
2023-05-15 11:11:08.842 | INFO     | __main__:build_data:57 - len(train_dataset): 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.87it/s]
2023-05-15 11:11:12.413 | INFO     | hcpdiff.loggers.cli_logger:_info:30 - ***** Running training *****
2023-05-15 11:11:12.414 | INFO     | hcpdiff.loggers.cli_logger:_info:30 -   Num batches each epoch = 1
2023-05-15 11:11:12.414 | INFO     | hcpdiff.loggers.cli_logger:_info:30 -   Num Steps = 1000
2023-05-15 11:11:12.414 | INFO     | hcpdiff.loggers.cli_logger:_info:30 -   Instantaneous batch size per device = 4
2023-05-15 11:11:12.414 | INFO     | hcpdiff.loggers.cli_logger:_info:30 -   Total train batch size (w. parallel, distributed & accumulation) = 4
2023-05-15 11:11:12.414 | INFO     | hcpdiff.loggers.cli_logger:_info:30 -   Gradient Accumulation steps = 1
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/runpy.py:196 in _run_module_as_main                │
│                                                                                                  │
│   193 │   main_globals = sys.modules["__main__"].__dict__                                        │
│   194 │   if alter_argv:                                                                         │
│   195 │   │   sys.argv[0] = mod_spec.origin                                                      │
│ ❱ 196 │   return _run_code(code, main_globals, None,                                             │
│   197 │   │   │   │   │    "__main__", mod_spec)                                                 │
│   198                                                                                            │
│   199 def run_module(mod_name, init_globals=None,                                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/runpy.py:86 in _run_code                           │
│                                                                                                  │
│    83 │   │   │   │   │      __loader__ = loader,                                                │
│    84 │   │   │   │   │      __package__ = pkg_name,                                             │
│    85 │   │   │   │   │      __spec__ = mod_spec)                                                │
│ ❱  86 │   exec(code, run_globals)                                                                │
│    87 │   return run_globals                                                                     │
│    88                                                                                            │
│    89 def _run_module_code(code, init_globals=None,                                              │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac_single.py:105 in <module>                              │
│                                                                                                  │
│   102 │                                                                                          │
│   103 │   conf = load_config_with_cli(args.cfg, args_list=sys.argv[3:]) # skip --cfg             │
│   104 │   trainer=TrainerSingleCard(conf)                                                        │
│ ❱ 105 │   trainer.train()                                                                        │
│   106                                                                                            │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac.py:409 in train                                        │
│                                                                                                  │
│   406 │   │                                                                                      │
│   407 │   │   loss_sum = np.ones(30)                                                             │
│   408 │   │   for data_list in self.train_loader_group:                                          │
│ ❱ 409 │   │   │   loss = self.train_one_step(data_list)                                          │
│   410 │   │   │   loss_sum[self.global_step%len(loss_sum)] = loss                                │
│   411 │   │   │                                                                                  │
│   412 │   │   │   self.global_step += 1                                                          │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac.py:501 in train_one_step                               │
│                                                                                                  │
│   498 │   │   │   │   other_datas = {k:v.to(self.device, dtype=self.weight_dtype) for k, v in    │
│   499 │   │   │   │                                                                              │
│   500 │   │   │   │   latents = self.get_latents(image, self.train_loader_group.get_dataset(id   │
│ ❱ 501 │   │   │   │   model_pred, target, timesteps = self.forward(latents, prompt_ids, **othe   │
│   502 │   │   │   │   loss = self.get_loss(model_pred, target, timesteps, att_mask) * self.tra   │
│   503 │   │   │   │   self.accelerator.backward(loss)                                            │
│   504                                                                                            │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac.py:479 in forward                                      │
│                                                                                                  │
│   476 │   │                                                                                      │
│   477 │   │   # CFG context for DreamArtist                                                      │
│   478 │   │   noisy_latents, timesteps = self.cfg_context.pre(noisy_latents, timesteps)          │
│ ❱ 479 │   │   model_pred = self.encode_decode(prompt_ids, noisy_latents, timesteps, **kwargs)    │
│   480 │   │   model_pred = self.cfg_context.post(model_pred)                                     │
│   481 │   │                                                                                      │
│   482 │   │   # Get the target for loss depending on the prediction type                         │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac_single.py:78 in encode_decode                          │
│                                                                                                  │
│    75 │   │   │   │   feeder(input_all)                                                          │
│    76 │   │                                                                                      │
│    77 │   │   encoder_hidden_states = self.text_encoder(prompt_ids, output_hidden_states=True)   │
│ ❱  78 │   │   model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample     │
│    79 │   │   return model_pred                                                                  │
│    80 │                                                                                          │
│    81 │   def get_loss(self, model_pred, target, timesteps, att_mask):                           │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/accelerate/utils/operations.py:521   │
│ in forward                                                                                       │
│                                                                                                  │
│   518 │   model_forward = ConvertOutputsToFp32(model_forward)                                    │
│   519 │                                                                                          │
│   520 │   def forward(*args, **kwargs):                                                          │
│ ❱ 521 │   │   return model_forward(*args, **kwargs)                                              │
│   522 │                                                                                          │
│   523 │   # To act like a decorator so that it can be popped when doing `extract_model_from_pa   │
│   524 │   forward.__wrapped__ = model_forward                                                    │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/accelerate/utils/operations.py:509   │
│ in __call__                                                                                      │
│                                                                                                  │
│   506 │   │   update_wrapper(self, model_forward)                                                │
│   507 │                                                                                          │
│   508 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 509 │   │   return convert_to_fp32(self.model_forward(*args, **kwargs))                        │
│   510 │                                                                                          │
│   511 │   def __getstate__(self):                                                                │
│   512 │   │   raise pickle.PicklingError(                                                        │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/amp/autocast_mode.py:12 in     │
│ decorate_autocast                                                                                │
│                                                                                                  │
│     9 │   @functools.wraps(func)                                                                 │
│    10 │   def decorate_autocast(*args, **kwargs):                                                │
│    11 │   │   with autocast_instance:                                                            │
│ ❱  12 │   │   │   return func(*args, **kwargs)                                                   │
│    13 │   decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in    │
│    14 │   return decorate_autocast                                                               │
│    15                                                                                            │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.p │
│ y:724 in forward                                                                                 │
│                                                                                                  │
│   721 │   │   down_block_res_samples = (sample,)                                                 │
│   722 │   │   for downsample_block in self.down_blocks:                                          │
│   723 │   │   │   if hasattr(downsample_block, "has_cross_attention") and downsample_block.has   │
│ ❱ 724 │   │   │   │   sample, res_samples = downsample_block(                                    │
│   725 │   │   │   │   │   hidden_states=sample,                                                  │
│   726 │   │   │   │   │   temb=emb,                                                              │
│   727 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                           │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:8 │
│ 59 in forward                                                                                    │
│                                                                                                  │
│    856 │   │   │   │   │   return custom_forward                                                 │
│    857 │   │   │   │                                                                             │
│    858 │   │   │   │   hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(  │
│ ❱  859 │   │   │   │   hidden_states = torch.utils.checkpoint.checkpoint(                        │
│    860 │   │   │   │   │   create_custom_forward(attn, return_dict=False),                       │
│    861 │   │   │   │   │   hidden_states,                                                        │
│    862 │   │   │   │   │   encoder_hidden_states,                                                │
│                                                                                                  │
│ /data/yoke/HCP-Diffusion/hcpdiff/train_ac.py:48 in checkpoint_fix                                │
│                                                                                                  │
│    45 # fix checkpoint bug for train part of model                                               │
│    46 import torch.utils.checkpoint                                                              │
│    47 def checkpoint_fix(function, *args, use_reentrant: bool = False, checkpoint_raw = torch.   │
│ ❱  48 │   return checkpoint_raw(function, *args, use_reentrant=use_reentrant, **kwargs)          │
│    49 torch.utils.checkpoint.checkpoint = checkpoint_fix                                         │
│    50                                                                                            │
│    51 class Trainer:                                                                             │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/utils/checkpoint.py:237 in     │
│ checkpoint                                                                                       │
│                                                                                                  │
│   234 │   if use_reentrant:                                                                      │
│   235 │   │   return CheckpointFunction.apply(function, preserve, *args)                         │
│   236 │   else:                                                                                  │
│ ❱ 237 │   │   return _checkpoint_without_reentrant(                                              │
│   238 │   │   │   function,                                                                      │
│   239 │   │   │   preserve,                                                                      │
│   240 │   │   │   *args                                                                          │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/utils/checkpoint.py:383 in     │
│ _checkpoint_without_reentrant                                                                    │
│                                                                                                  │
│   380 │   │   return storage.pop(x)                                                              │
│   381 │                                                                                          │
│   382 │   with torch.autograd.graph.saved_tensors_hooks(pack, unpack):                           │
│ ❱ 383 │   │   output = function(*args)                                                           │
│   384 │   │   if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:         │
│   385 │   │   │   # Cuda was not initialized before running the forward, so we didn't            │
│   386 │   │   │   # stash the CUDA state.                                                        │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:8 │
│ 52 in custom_forward                                                                             │
│                                                                                                  │
│    849 │   │   │   │   def create_custom_forward(module, return_dict=None):                      │
│    850 │   │   │   │   │   def custom_forward(*inputs):                                          │
│    851 │   │   │   │   │   │   if return_dict is not None:                                       │
│ ❱  852 │   │   │   │   │   │   │   return module(*inputs, return_dict=return_dict)               │
│    853 │   │   │   │   │   │   else:                                                             │
│    854 │   │   │   │   │   │   │   return module(*inputs)                                        │
│    855                                                                                           │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:2 │
│ 65 in forward                                                                                    │
│                                                                                                  │
│   262 │   │                                                                                      │
│   263 │   │   # 2. Blocks                                                                        │
│   264 │   │   for block in self.transformer_blocks:                                              │
│ ❱ 265 │   │   │   hidden_states = block(                                                         │
│   266 │   │   │   │   hidden_states,                                                             │
│   267 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│   268 │   │   │   │   timestep=timestep,                                                         │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/attention.py:331 in │
│ forward                                                                                          │
│                                                                                                  │
│   328 │   │   │   # TODO (Birch-San): Here we should prepare the encoder_attention mask correc   │
│   329 │   │   │   # prepare attention mask here                                                  │
│   330 │   │   │                                                                                  │
│ ❱ 331 │   │   │   attn_output = self.attn2(                                                      │
│   332 │   │   │   │   norm_hidden_states,                                                        │
│   333 │   │   │   │   encoder_hidden_states=encoder_hidden_states,                               │
│   334 │   │   │   │   attention_mask=encoder_attention_mask,                                     │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/nn/modules/module.py:1130 in   │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/attention_processor │
│ .py:267 in forward                                                                               │
│                                                                                                  │
│    264 │   │   # The `Attention` class can call different attention processors / attention func  │
│    265 │   │   # here we simply pass along all tensors to the selected processor class           │
│    266 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e  │
│ ❱  267 │   │   return self.processor(                                                            │
│    268 │   │   │   self,                                                                         │
│    269 │   │   │   hidden_states,                                                                │
│    270 │   │   │   encoder_hidden_states=encoder_hidden_states,                                  │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/diffusers/models/attention_processor │
│ .py:696 in __call__                                                                              │
│                                                                                                  │
│    693 │   │   key = attn.head_to_batch_dim(key).contiguous()                                    │
│    694 │   │   value = attn.head_to_batch_dim(value).contiguous()                                │
│    695 │   │                                                                                     │
│ ❱  696 │   │   hidden_states = xformers.ops.memory_efficient_attention(                          │
│    697 │   │   │   query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=att  │
│    698 │   │   )                                                                                 │
│    699 │   │   hidden_states = hidden_states.to(query.dtype)                                     │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:192 in │
│ memory_efficient_attention                                                                       │
│                                                                                                  │
│   189 │   │   and options.                                                                       │
│   190 │   :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``                     │
│   191 │   """                                                                                    │
│ ❱ 192 │   return _memory_efficient_attention(                                                    │
│   193 │   │   Inputs(                                                                            │
│   194 │   │   │   query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale       │
│   195 │   │   ),                                                                                 │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:295 in │
│ _memory_efficient_attention                                                                      │
│                                                                                                  │
│   292 │   │   )                                                                                  │
│   293 │                                                                                          │
│   294 │   output_shape = inp.normalize_bmhk()                                                    │
│ ❱ 295 │   return _fMHA.apply(                                                                    │
│   296 │   │   op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale                 │
│   297 │   ).reshape(output_shape)                                                                │
│   298                                                                                            │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:41 in  │
│ forward                                                                                          │
│                                                                                                  │
│    38 │   │   op_fw = op[0] if op is not None else None                                          │
│    39 │   │   op_bw = op[1] if op is not None else None                                          │
│    40 │   │                                                                                      │
│ ❱  41 │   │   out, op_ctx = _memory_efficient_attention_forward_requires_grad(                   │
│    42 │   │   │   inp=inp, op=op_fw                                                              │
│    43 │   │   )                                                                                  │
│    44                                                                                            │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:323 in │
│ _memory_efficient_attention_forward_requires_grad                                                │
│                                                                                                  │
│   320 │   │   op = _dispatch_fw(inp)                                                             │
│   321 │   else:                                                                                  │
│   322 │   │   _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)    │
│ ❱ 323 │   out = op.apply(inp, needs_gradient=True)                                               │
│   324 │   assert out[1] is not None                                                              │
│   325 │   return (out[0].reshape(output_shape), out[1])                                          │
│   326                                                                                            │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/xformers/ops/fmha/cutlass.py:175 in  │
│ apply                                                                                            │
│                                                                                                  │
│   172 │   │   if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:                      │
│   173 │   │   │   raise NotImplementedError("Unsupported attn_bias type")                        │
│   174 │   │   seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)                    │
│ ❱ 175 │   │   out, lse, rng_seed, rng_offset = cls.OPERATOR(                                     │
│   176 │   │   │   query=inp.query,                                                               │
│   177 │   │   │   key=inp.key,                                                                   │
│   178 │   │   │   value=inp.value,                                                               │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/torch/_ops.py:143 in __call__        │
│                                                                                                  │
│   140 │   │   # is still callable from JIT                                                       │
│   141 │   │   # We save the function ptr as the `op` attribute on                                │
│   142 │   │   # OpOverloadPacket to access it here.                                              │
│ ❱ 143 │   │   return self._op(*args, **kwargs or {})                                             │
│   144 │                                                                                          │
│   145 │   # TODO: use this to make a __dir__                                                     │
│   146 │   def overloads(self):                                                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /data/yoke/anaconda3/envs/hcpd/bin/accelerate:8 in <module>                                      │
│                                                                                                  │
│   5 from accelerate.commands.accelerate_cli import main                                          │
│   6 if __name__ == '__main__':                                                                   │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])                         │
│ ❱ 8 │   sys.exit(main())                                                                         │
│   9                                                                                              │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.p │
│ y:45 in main                                                                                     │
│                                                                                                  │
│   42 │   │   exit(1)                                                                             │
│   43 │                                                                                           │
│   44 │   # Run                                                                                   │
│ ❱ 45 │   args.func(args)                                                                         │
│   46                                                                                             │
│   47                                                                                             │
│   48 if __name__ == "__main__":                                                                  │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/accelerate/commands/launch.py:918 in │
│ launch_command                                                                                   │
│                                                                                                  │
│   915 │   elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMA   │
│   916 │   │   sagemaker_launcher(defaults, args)                                                 │
│   917 │   else:                                                                                  │
│ ❱ 918 │   │   simple_launcher(args)                                                              │
│   919                                                                                            │
│   920                                                                                            │
│   921 def main():                                                                                │
│                                                                                                  │
│ /data/yoke/anaconda3/envs/hcpd/lib/python3.10/site-packages/accelerate/commands/launch.py:580 in │
│ simple_launcher                                                                                  │
│                                                                                                  │
│   577 │   process.wait()                                                                         │
│   578 │   if process.returncode != 0:                                                            │
│   579 │   │   if not args.quiet:                                                                 │
│ ❱ 580 │   │   │   raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)    │
│   581 │   │   else:                                                                              │
│   582 │   │   │   sys.exit(1)                                                                    │
│   583                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
CalledProcessError: Command '['/data/yoke/anaconda3/envs/hcpd/bin/python', '-m', 'hcpdiff.train_ac_single', '--cfg', 'cfgs/train/examples/DreamArtist++2.yaml']' returned non-zero exit status 1.