aim-uofa / FreeCustom

[CVPR 2024] Official PyTorch implementation of FreeCustom: Tuning-Free Customized Image Generation for Multi-Concept Composition
https://aim-uofa.github.io/FreeCustom/
BSD 2-Clause "Simplified" License

RuntimeError: The size of tensor a (1024) must match the size of tensor b (3072) at non-singleton dimension 2 #11

Closed: lyx-JuneSnow closed this issue 3 months ago

lyx-JuneSnow commented 4 months ago
(lyx_FreeCustom) [zzy@localhost FreeCustom]$ python freecustom_stable_diffusion.py
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.
  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
config: {'mark': '', 'model_path': 'runwayml/stable-diffusion-v1-5', 'gpu': 0, 'start_step': 0, 'end_step': 50, 'layer_idx': [10, 11, 12, 13, 14, 15], 'seeds': [2], 'ref_image_infos': {'Data/DeepFashion_generate_test/image/face.jpg': 'face', 'Data/DeepFashion_generate_test/image/shirt.jpg': 'a short-sleeve shirt with cotton fabric and pure color patterns', 'Data/DeepFashion_generate_test/image/pants.jpg': 'pants'}, 'target_prompt': 'someone with a human head and face wearing a short-sleeve shirt with cotton fabric and pure color patterns and pants', 'use_null_ref_prompts': False, 'mask_weights': [3.0, 3.0, 3.0], 'negative_prompt': 'lowres, bad anatomy, text, error, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry', 'style_fidelity': 1}
Loading pipeline components...: 100%|█████████████████████| 7/7 [00:00<00:00, 13.57it/s]
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:201: FutureWarning: The configuration file of this scheduler: DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.28.2",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}
 is outdated. `steps_offset` should be set to 1 instead of 0. Please make sure to update the config accordingly as leaving `steps_offset` might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json` file
  deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
Seed set to 2
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:283: FutureWarning: `_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.
  deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
  0%|                                                            | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/zzy/lyx/FreeCustom/freecustom_stable_diffusion.py", line 105, in <module>
    images = model(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/zzy/lyx/FreeCustom/pipelines/pipeline_stable_diffusion_freecustom.py", line 129, in __call__
    noise_pred = self.unet(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1285, in forward
    sample = upsample_block(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
    hidden_states = attn(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 440, in forward
    hidden_states = block(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/attention.py", line 329, in forward
    attn_output = self.attn1(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/hack_attention.py", line 46, in forward
    out = mrsa(
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 38, in __call__
    out = self.mrsa_forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 125, in mrsa_forward
    out_u_target = self.attn_batch(qu_o, ku_cat, vu_cat, None, None, is_cross, place_in_unet, num_heads, attn_batch_type='mrsa', **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 66, in attn_batch
    sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)
RuntimeError: The size of tensor a (1024) must match the size of tensor b (3072) at non-singleton dimension 2

Above is the error output from my run. I then added the following print statements to mrsa.py:

# Print the tensor shapes before calling attn_batch
print(f"qu_o shape: {qu_o.shape}")  # (batch * num_heads, sequence_length, head_dim)
print(f"ku_cat shape: {ku_cat.shape}")  # (batch * num_heads, sequence_length, head_dim)
print(f"vu_cat shape: {vu_cat.shape}")  # (batch * num_heads, sequence_length, head_dim)
# Make sure the shapes match
if ku_cat.shape[2] != qu_o.shape[2] or vu_cat.shape[2] != qu_o.shape[2]:
    raise ValueError(f"Shape mismatch: qu_o.shape={qu_o.shape}, ku_cat.shape={ku_cat.shape}, vu_cat.shape={vu_cat.shape}")

The output was:

q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 4096, 40])
v shape: torch.Size([64, 4096, 40])
sim shape: torch.Size([64, 4096, 4096])
attn shape: torch.Size([64, 4096, 4096])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 77, 40])
v shape: torch.Size([64, 77, 40])
sim shape: torch.Size([64, 4096, 77])
attn shape: torch.Size([64, 4096, 77])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 4096, 40])
v shape: torch.Size([64, 4096, 40])
sim shape: torch.Size([64, 4096, 4096])
attn shape: torch.Size([64, 4096, 4096])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 77, 40])
v shape: torch.Size([64, 77, 40])
sim shape: torch.Size([64, 4096, 77])
attn shape: torch.Size([64, 4096, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 77, 80])
v shape: torch.Size([64, 77, 80])
sim shape: torch.Size([64, 1024, 77])
attn shape: torch.Size([64, 1024, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 77, 80])
v shape: torch.Size([64, 77, 80])
sim shape: torch.Size([64, 1024, 77])
attn shape: torch.Size([64, 1024, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 64, 160])
k shape: torch.Size([64, 64, 160])
v shape: torch.Size([64, 64, 160])
sim shape: torch.Size([64, 64, 64])
attn shape: torch.Size([64, 64, 64])
q shape: torch.Size([64, 64, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 64, 77])
attn shape: torch.Size([64, 64, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
qu_o shape: torch.Size([8, 1024, 80])
ku_cat shape: torch.Size([8, 4096, 80])
vu_cat shape: torch.Size([8, 4096, 80])
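The per-layer shapes above follow a regular pattern. This sketch assumes Stable Diffusion v1.5 at 512×512, which gives a 64×64 latent. It also assumes 8 attention heads and block widths of 320, 640 and 1280 channels.

Under those assumptions, each self-attention call sees (64 / 2^k)^2 query tokens with a per-head dimension of width / 8. The cross-attention calls have 77 keys, which is the CLIP text sequence length. The helper name `expected_self_attn_shape` is hypothetical and not part of FreeCustom.

```python
# Hypothetical helper, not part of FreeCustom. Assumes SD v1.5 at 512x512:
# a 64x64 latent, 8 attention heads, and block widths 320/640/1280.
LATENT_SIDE = 64
NUM_HEADS = 8
# Channel width at each downsampling level; level 3 (8x8) is the mid block.
WIDTHS = {0: 320, 1: 640, 2: 1280, 3: 1280}


def expected_self_attn_shape(level: int, batch_times_heads: int = 64) -> tuple[int, int, int]:
    """Return (batch * heads, tokens, head_dim) for a self-attention call at `level`."""
    side = LATENT_SIDE // (2 ** level)
    tokens = side * side
    head_dim = WIDTHS[level] // NUM_HEADS
    return (batch_times_heads, tokens, head_dim)


for level in range(4):
    print(level, expected_self_attn_shape(level))

# With 8 heads, a leading 64 means 8 images per call, and qu_o's leading 8
# corresponds to the 8 heads of a single image (B = 8 // 8 = 1 in attn_batch).
print("images per call:", 64 // NUM_HEADS)
```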

qu_o and ku_cat/vu_cat have mismatched shapes: qu_o is [8, 1024, 80], while ku_cat and vu_cat are [8, 4096, 80]. I made the following adjustment:

# Make sure the shapes match
if ku_cat.shape[2] != qu_o.shape[2] or vu_cat.shape[2] != qu_o.shape[2]:
    print(f"Reshaping ku_cat and vu_cat to match qu_o")
    ku_cat = ku_cat.view(qu_o.shape[0], qu_o.shape[1], qu_o.shape[2])
    vu_cat = vu_cat.view(qu_o.shape[0], qu_o.shape[1], qu_o.shape[2])
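The printed shapes show why this adjustment has no effect. The earlier check and this reshape use the same condition. The sketch below replays that condition on plain tuples copied from the output above, so no tensors are needed.

```python
from math import prod

# Shapes printed just before attn_batch: (batch * heads, sequence_length, head_dim).
qu_o = (8, 1024, 80)
ku_cat = (8, 4096, 80)
vu_cat = (8, 4096, 80)

# The condition compares index 2, i.e. head_dim (80 vs 80), so it evaluates to False
# and neither the ValueError nor the view() calls are ever reached.
guard_fires = ku_cat[2] != qu_o[2] or vu_cat[2] != qu_o[2]
print("guard fires:", guard_fires)

# Even if it fired, view() needs the same element count, and 8*4096*80 != 8*1024*80.
view_possible = prod(ku_cat) == prod(qu_o)
print("view possible:", view_possible)

# The lengths differ only along index 1 (sequence length), by an exact factor of 4.
print("key/query length ratio:", ku_cat[1] // qu_o[1])
```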

This still gave the same error, so I changed the following in mrsa.py:

def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
    B = q.shape[0] // num_heads
    H = W = int(np.sqrt(q.shape[1]))
    q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) 
    k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
    v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

    sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")

    if kwargs.get("attn_batch_type") == 'mrsa':
        sim_own, sim_refs = sim[..., :H*W], sim[..., H*W:]
        sim_or = [sim_own]
        for i, (ref_mask, mask_weight) in enumerate(zip(self.ref_masks, self.mask_weights)):
            ref_mask = self.get_ref_mask(ref_mask, mask_weight, H, W)
            ref_mask = ref_mask.view(1, -1)  # reshape ref_mask
            print(f"sim_ref shape before: {sim_refs[..., H*W*i: H*W*(i+1)].shape}")
            print(f"ref_mask shape: {ref_mask.shape}")
            sim_ref = sim_refs[..., H*W*i: H*W*(i+1)]
            sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)
            print(f"sim_ref shape after: {sim_ref.shape}")
            sim_or.append(sim_ref)
        sim = torch.cat(sim_or, dim=-1)
    attn = sim.softmax(-1)

    # viz attention map within MRSA module
    if self.viz_cfg.viz_attention_map and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_map_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_map_at_layer:
        visualize_attention_map(attn, self.viz_cfg, self.cur_step, self.cur_att_layer//2)

    # viz feature correspondence within MRSA module
    if self.viz_cfg.viz_feature_correspondence and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_corr_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_corr_at_layer:
        visualize_correspondence(self.viz_cfg, attn, self.cur_step, self.cur_att_layer//2)

    if len(attn) == 2 * len(v):
        v = torch.cat([v] * 2)
    out = torch.einsum("h i j, h j d -> h i d", attn, v)
    out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
    return out
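In attn_batch, the key axis of `sim` holds the target's own H*W tokens, followed by H*W tokens for each reference image. That layout explains why ku_cat has 4096 = (1 + 3) × 1024 keys for the three reference images in the config.

The sketch below reproduces that slicing and masking on small random tensors. It assumes each reference mask flattens to exactly H*W values, as from a single-channel mask, which is what the `H*W*i : H*W*(i+1)` slice expects. The sizes and the half-background mask are illustrative only.

```python
import torch

num_heads, H, W, num_refs = 8, 4, 4, 3
HW = H * W
mask_weight = 3.0
# sim: (heads, query tokens, own tokens + num_refs * reference tokens)
sim = torch.randn(num_heads, HW, HW * (1 + num_refs))

sim_own, sim_refs = sim[..., :HW], sim[..., HW:]
sim_or = [sim_own]
for i in range(num_refs):
    ref_mask = torch.ones(HW) * mask_weight  # single-channel mask flattened to H*W values
    ref_mask[: HW // 2] = 0                  # pretend the first half of the reference is background
    sim_ref = sim_refs[..., HW * i: HW * (i + 1)]  # (heads, HW, HW)
    # A (HW,) mask broadcasts along the last (key) axis: background keys get a huge
    # negative bias, foreground keys get +mask_weight.
    sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)
    sim_or.append(sim_ref)
sim = torch.cat(sim_or, dim=-1)
attn = sim.softmax(-1)
print(sim.shape)
```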

Running this version printed:

sim_ref shape before: torch.Size([8, 1024, 1024])
ref_mask shape: torch.Size([1, 3072])
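These two shapes alone reproduce the original RuntimeError, which pins the failure on the mask addition in attn_batch. The minimal sketch below uses zero tensors of the printed shapes. It also shows that a mask with exactly H*W = 1024 values broadcasts without error.

```python
import torch

sim_ref = torch.zeros(8, 1024, 1024)  # "sim_ref shape before"
ref_mask = torch.zeros(1, 3072)       # "ref_mask shape"

# Broadcasting aligns trailing dimensions, so 1024 keys are compared with 3072 mask values.
message = ""
try:
    _ = sim_ref + ref_mask
except RuntimeError as err:
    message = str(err)
print(message)

# A mask with exactly H*W = 1024 values broadcasts along the key axis.
ok = sim_ref + torch.zeros(1024)
print(ok.shape)
```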

sim_ref and ref_mask have mismatched shapes: sim_ref is [8, 1024, 1024], while ref_mask is [1, 3072]. I changed mrsa.py further to:

def get_ref_mask(self, ref_mask, mask_weight, H, W):
    ref_mask = ref_mask.float() * mask_weight
    ref_mask = F.interpolate(ref_mask, (H, W))
    ref_mask = ref_mask.flatten()
    ref_mask = ref_mask.unsqueeze(0)  # added so that ref_mask has shape [1, H*W]
    return ref_mask

and

def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
    B = q.shape[0] // num_heads
    H = W = int(np.sqrt(q.shape[1]))
    q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) 
    k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
    v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

    sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")

    if kwargs.get("attn_batch_type") == 'mrsa':
        sim_own, sim_refs = sim[..., :H*W], sim[..., H*W:]
        sim_or = [sim_own]
        for i, (ref_mask, mask_weight) in enumerate(zip(self.ref_masks, self.mask_weights)):
            ref_mask = self.get_ref_mask(ref_mask, mask_weight, H, W)
            ref_mask = ref_mask.repeat(sim_refs.size(0), 1)  # make sure ref_mask has the right shape
            print(f"sim_ref shape before: {sim_refs[..., H*W*i: H*W*(i+1)].shape}")
            print(f"ref_mask shape: {ref_mask.shape}")
            sim_ref = sim_refs[..., H*W*i: H*W*(i+1)]
            sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)
            print(f"sim_ref shape after: {sim_ref.shape}")
            sim_or.append(sim_ref)
        sim = torch.cat(sim_or, dim=-1)
    attn = sim.softmax(-1)

    # viz attention map within MRSA module
    if self.viz_cfg.viz_attention_map and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_map_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_map_at_layer:
        visualize_attention_map(attn, self.viz_cfg, self.cur_step, self.cur_att_layer//2)

    # viz feature correspondence within MRSA module
    if self.viz_cfg.viz_feature_correspondence and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_corr_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_corr_at_layer:
        visualize_correspondence(self.viz_cfg, attn, self.cur_step, self.cur_att_layer//2)

    if len(attn) == 2 * len(v):
        v = torch.cat([v] * 2)
    out = torch.einsum("h i j, h j d -> h i d", attn, v)
    out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
    return out

The result:

qu_o shape: torch.Size([8, 1024, 80])
ku_cat shape: torch.Size([8, 4096, 80])
vu_cat shape: torch.Size([8, 4096, 80])
sim_ref shape before: torch.Size([8, 1024, 1024])
ref_mask shape: torch.Size([8, 3072])
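Repeating the mask along a new leading axis only changes its first dimension. The trailing dimension, which broadcasting compares against the 1024 keys, is still 3072.

The sketch below checks a few candidate mask shapes against `sim_ref`. The `[8, 1024]` and `[8, 1, 1024]` shapes are hypothetical alternatives for illustration, not shapes the repository uses. A 2-D per-head mask lines up with the query axis rather than the key axis, so it would fail even with the right length.

```python
import torch

sim_ref = torch.zeros(8, 1024, 1024)          # (heads, query tokens, key tokens)
repeated = torch.zeros(1, 3072).repeat(8, 1)  # what repeat(sim_refs.size(0), 1) produced


def broadcasts(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if a + b is a valid broadcast, False if PyTorch rejects it."""
    try:
        _ = a + b
        return True
    except RuntimeError:
        return False


print(repeated.shape)
print(broadcasts(sim_ref, repeated))                 # trailing 3072 still meets 1024 keys
print(broadcasts(sim_ref, torch.zeros(8, 1024)))     # the 8 lines up with the 1024 query axis
print(broadcasts(sim_ref, torch.zeros(8, 1, 1024)))  # hypothetical per-head mask over keys
print(broadcasts(sim_ref, torch.zeros(1024)))        # one mask shared by every head and query
```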

Next I changed it to:

def get_ref_mask(self, ref_mask, mask_weight, H, W):
    ref_mask = ref_mask.float() * mask_weight
    ref_mask = F.interpolate(ref_mask, (H, W))
    ref_mask = ref_mask.flatten(start_dim=1)  # flatten but keep the batch dimension
    return ref_mask

def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
    B = q.shape[0] // num_heads
    H = W = int(np.sqrt(q.shape[1]))
    q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads) 
    k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
    v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

    sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")

    if kwargs.get("attn_batch_type") == 'mrsa':
        sim_own, sim_refs = sim[..., :H*W], sim[..., H*W:]
        sim_or = [sim_own]
        for i, (ref_mask, mask_weight) in enumerate(zip(self.ref_masks, self.mask_weights)):
            ref_mask = self.get_ref_mask(ref_mask, mask_weight, H, W)
            ref_mask = ref_mask.view(sim_refs.size(0), -1)  # make sure ref_mask has the right shape
            print(f"sim_ref shape before: {sim_refs[..., H*W*i: H*W*(i+1)].shape}")
            print(f"ref_mask shape: {ref_mask.shape}")
            sim_ref = sim_refs[..., H*W*i: H*W*(i+1)]
            sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)
            print(f"sim_ref shape after: {sim_ref.shape}")
            sim_or.append(sim_ref)
        sim = torch.cat(sim_or, dim=-1)
    attn = sim.softmax(-1)

    # viz attention map within MRSA module
    if self.viz_cfg.viz_attention_map and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_map_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_map_at_layer:
        visualize_attention_map(attn, self.viz_cfg, self.cur_step, self.cur_att_layer//2)

    # viz feature correspondence within MRSA module
    if self.viz_cfg.viz_feature_correspondence and \
        kwargs.get("attn_batch_type") == 'mrsa' and \
        self.cur_step in self.viz_cfg.viz_corr_at_step and \
        self.cur_att_layer // 2 in self.viz_cfg.viz_corr_at_layer:
        visualize_correspondence(self.viz_cfg, attn, self.cur_step, self.cur_att_layer//2)

    if len(attn) == 2 * len(v):
        v = torch.cat([v] * 2)
    out = torch.einsum("h i j, h j d -> h i d", attn, v)
    out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
    return out

The result:

sim_ref shape before: torch.Size([8, 1024, 1024])
ref_mask shape: torch.Size([8, 384])
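Both 3072 and 384 follow from a single hypothesis: each reference mask reaches get_ref_mask as a 4-D tensor with three channels, for example a mask image loaded as RGB. The issue does not show how the masks are loaded, so this is an assumption.

Under it, interpolating to (H, W) = (32, 32) and flattening gives 3 × 1024 = 3072 values. `view(8, -1)` then splits those into rows of 3072 / 8 = 384. Keeping a single channel gives the H*W = 1024 values the slice expects. The 512×512 source size is arbitrary.

```python
import torch
import torch.nn.functional as F

H = W = 32          # 1024 query tokens at the failing layer
mask_weight = 3.0   # from mask_weights in the config

# Hypothetical: a binary mask image loaded with 3 (RGB) channels, shape (N, C, h, w).
rgb_mask = torch.ones(1, 3, 512, 512)

flat = F.interpolate(rgb_mask.float() * mask_weight, (H, W)).flatten()
print(flat.shape)   # 3 * 32 * 32 values

rows = F.interpolate(rgb_mask.float() * mask_weight, (H, W)).flatten(start_dim=1).view(8, -1)
print(rows.shape)   # 3072 / 8 values per row

# Keeping one channel (the channels of a grayscale-as-RGB image are identical) gives H*W values.
single = F.interpolate(rgb_mask[:, :1].float() * mask_weight, (H, W)).flatten()
print(single.shape)
```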

Then I changed it to:

def get_ref_mask(self, ref_mask, mask_weight, H, W):
    ref_mask = ref_mask.float() * mask_weight
    ref_mask = F.interpolate(ref_mask.unsqueeze(0), (H, W)).squeeze(0)
    ref_mask = ref_mask.flatten(start_dim=1)  # flatten but keep the batch dimension
    return ref_mask
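The earlier `F.interpolate(ref_mask, (H, W))` call succeeded with a two-element size, so ref_mask must have been 4-D at that point. `F.interpolate` treats every dimension after N and C as spatial, and the size must have one entry per spatial dimension.

Adding `unsqueeze(0)` makes the tensor 5-D, which `F.interpolate` treats as volumetric input, so a two-element size no longer fits. The `[1, 3, 64, 64]` mask below is hypothetical. The exception type may vary by PyTorch version, so the sketch catches both ValueError and RuntimeError.

```python
import torch
import torch.nn.functional as F

H = W = 32
ref_mask = torch.ones(1, 3, 64, 64)  # hypothetical 4-D mask (N, C, h, w)

resized = F.interpolate(ref_mask, (H, W))  # two spatial dims, two sizes: works
print(resized.shape)

five_d = ref_mask.unsqueeze(0)  # (1, 1, 3, 64, 64): now three "spatial" dims
error = None
try:
    F.interpolate(five_d, (H, W))
except (ValueError, RuntimeError) as err:
    error = err
print(type(error).__name__, error)
```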

The error output then became:

(lyx_FreeCustom) [zzy@localhost FreeCustom]$ python freecustom_stable_diffusion.py
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: Transformer2DModelOutput is deprecated and will be removed in version 1.0.0. Importing Transformer2DModelOutput from diffusers.models.transformer_2d is deprecated and this will be removed in a future version. Please use from diffusers.models.modeling_outputs import Transformer2DModelOutput, instead.
  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
config: {'mark': '', 'model_path': 'runwayml/stable-diffusion-v1-5', 'gpu': 0, 'start_step': 0, 'end_step': 50, 'layer_idx': [10, 11, 12, 13, 14, 15], 'seeds': [2], 'ref_image_infos': {'Data/DeepFashion_generate_test/image/face.jpg': 'face', 'Data/DeepFashion_generate_test/image/shirt.jpg': 'a short-sleeve shirt with cotton fabric and pure color patterns', 'Data/DeepFashion_generate_test/image/pants.jpg': 'pants'}, 'target_prompt': 'someone with a human head and face wearing a short-sleeve shirt with cotton fabric and pure color patterns and pants', 'use_null_ref_prompts': False, 'mask_weights': [3.0, 3.0, 3.0], 'negative_prompt': 'lowres, bad anatomy, text, error, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry', 'style_fidelity': 1}
Loading pipeline components...: 100%|█████████████████████| 7/7 [00:00<00:00, 13.61it/s]
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:201: FutureWarning: The configuration file of this scheduler: DDIMScheduler {
  "_class_name": "DDIMScheduler",
  "_diffusers_version": "0.28.2",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null
}
 is outdated. steps_offset should be set to 1 instead of 0. Please make sure to update the config accordingly as leaving steps_offset might led to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the scheduler/scheduler_config.json file
  deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
Data/DeepFashion_generate_test/image/pants.jpg
Seed set to 2
/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:283: FutureWarning: _encode_prompt() is deprecated and it will be removed in a future version. Use encode_prompt() instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.
  deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
  0%|                                                            | 0/50 [00:00<?, ?it/s]q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 4096, 40])
v shape: torch.Size([64, 4096, 40])
sim shape: torch.Size([64, 4096, 4096])
attn shape: torch.Size([64, 4096, 4096])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 77, 40])
v shape: torch.Size([64, 77, 40])
sim shape: torch.Size([64, 4096, 77])
attn shape: torch.Size([64, 4096, 77])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 4096, 40])
v shape: torch.Size([64, 4096, 40])
sim shape: torch.Size([64, 4096, 4096])
attn shape: torch.Size([64, 4096, 4096])
q shape: torch.Size([64, 4096, 40])
k shape: torch.Size([64, 77, 40])
v shape: torch.Size([64, 77, 40])
sim shape: torch.Size([64, 4096, 77])
attn shape: torch.Size([64, 4096, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 77, 80])
v shape: torch.Size([64, 77, 80])
sim shape: torch.Size([64, 1024, 77])
attn shape: torch.Size([64, 1024, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 77, 80])
v shape: torch.Size([64, 77, 80])
sim shape: torch.Size([64, 1024, 77])
attn shape: torch.Size([64, 1024, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 64, 160])
k shape: torch.Size([64, 64, 160])
v shape: torch.Size([64, 64, 160])
sim shape: torch.Size([64, 64, 64])
attn shape: torch.Size([64, 64, 64])
q shape: torch.Size([64, 64, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 64, 77])
attn shape: torch.Size([64, 64, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 256, 160])
v shape: torch.Size([64, 256, 160])
sim shape: torch.Size([64, 256, 256])
attn shape: torch.Size([64, 256, 256])
q shape: torch.Size([64, 256, 160])
k shape: torch.Size([64, 77, 160])
v shape: torch.Size([64, 77, 160])
sim shape: torch.Size([64, 256, 77])
attn shape: torch.Size([64, 256, 77])
q shape: torch.Size([64, 1024, 80])
k shape: torch.Size([64, 1024, 80])
v shape: torch.Size([64, 1024, 80])
sim shape: torch.Size([64, 1024, 1024])
attn shape: torch.Size([64, 1024, 1024])
qu_o shape: torch.Size([8, 1024, 80])
ku_cat shape: torch.Size([8, 4096, 80])
vu_cat shape: torch.Size([8, 4096, 80])
  0%|                                                            | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/zzy/lyx/FreeCustom/freecustom_stable_diffusion.py", line 106, in <module>
    images = model(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/zzy/lyx/FreeCustom/pipelines/pipeline_stable_diffusion_freecustom.py", line 129, in __call__
    noise_pred = self.unet(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1285, in forward
    sample = upsample_block(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
    hidden_states = attn(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 440, in forward
    hidden_states = block(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/diffusers/models/attention.py", line 329, in forward
    attn_output = self.attn1(
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/hack_attention.py", line 46, in forward
    out = mrsa(
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 37, in __call__
    out = self.mrsa_forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 152, in mrsa_forward
    out_u_target = self.attn_batch(qu_o, ku_cat, vu_cat, None, None, is_cross, place_in_unet, num_heads, attn_batch_type='mrsa', **kwargs)
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 63, in attn_batch
    ref_mask = self.get_ref_mask(ref_mask, mask_weight, H, W)
  File "/data/zzy/lyx/FreeCustom/freecustom/mrsa.py", line 46, in get_ref_mask
    ref_mask = F.interpolate(ref_mask.unsqueeze(0), (H, W)).squeeze(0)
  File "/data/zzy/anaconda3/envs/lyx_FreeCustom/lib/python3.10/site-packages/torch/nn/functional.py", line 3961, in interpolate
    raise ValueError(
ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [3, 128, 128] and output size of (32, 32). Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.
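Why the `unsqueeze(0)` change fails: `F.interpolate` treats every dimension after the first two (N, C) as spatial, so the number of entries in `size` must equal `input.dim() - 2`. A minimal sketch, assuming the mask reaches `get_ref_mask` as a (1, 3, 128, 128) tensor (inferred from the `[3, 128, 128]` in the error above, not taken from the repository):

```python
import torch
import torch.nn.functional as F

# (N, C, H, W): an RGB mask, as loaded in this thread
mask = torch.ones(1, 3, 128, 128)

# 4-D input: dims 2 and 3 are spatial, so a 2-element target size is valid
ok = F.interpolate(mask, size=(32, 32))
print(ok.shape)  # torch.Size([1, 3, 32, 32])

# unsqueeze(0) makes it 5-D: (1, 1, 3, 128, 128), so interpolate sees
# three spatial dims [3, 128, 128] but only two output sizes
err = None
try:
    F.interpolate(mask.unsqueeze(0), size=(32, 32))
except ValueError as e:
    err = str(e)
print(err)
```

Adding a batch dimension only hides the real problem: the extra channel axis is still there, just moved to where `interpolate` expects a spatial axis.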
lyx-JuneSnow commented 3 months ago

Here is my data: https://drive.google.com/file/d/1_Wur1flNbv_smSDXE2icahxeDJSiSU1L/view?usp=sharing

dingangui commented 3 months ago

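It is normal that the shapes of qu_o and ku_cat/vu_cat do not match. qu_o holds the data of only one batch, whereas ku_cat is several images concatenated together; the two can still be multiplied as matrices.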
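A minimal sketch of why these shapes are compatible: in attention only the head and feature dimensions must agree, while the query and key token counts may differ. The shapes are copied from the log (`qu_o` [8, 1024, 80], `ku_cat`/`vu_cat` [8, 4096, 80]); reading 4096 as four concatenated 32x32 token sets and using the usual 1/sqrt(d) scaling are assumptions about how FreeCustom builds its attention, not its actual code. The value-side einsum matches the one in the snippet at the top of this section.

```python
import torch

heads, d = 8, 80
# Queries for the single image being generated: 32 x 32 = 1024 tokens
qu_o = torch.randn(heads, 1024, d)
# Keys/values from several images concatenated along the token axis (assumed 4 x 1024)
ku_cat = torch.randn(heads, 4 * 1024, d)
vu_cat = torch.randn(heads, 4 * 1024, d)

# Scaled dot-product scores: every query token against every concatenated key token
sim = torch.einsum("h i d, h j d -> h i j", qu_o, ku_cat) * d ** -0.5  # (8, 1024, 4096)
attn = sim.softmax(dim=-1)
out = torch.einsum("h i j, h j d -> h i d", attn, vu_cat)  # (8, 1024, 80)
print(sim.shape, out.shape)
```

The output keeps the query's token count (1024), which is why the mismatch between `qu_o` and `ku_cat` is not the bug.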

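The problem first occurs at 'sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)'. Judging from the printed output, ref_mask is what has gone wrong. Is your mask file saved as a grayscale image? Masks should be saved with a single channel, not in RGB format.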
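A minimal sketch reproducing the error in the issue title. It assumes the unmodified mask step resizes a (1, C, 128, 128) mask to the 32x32 attention resolution with `F.interpolate` and then flattens it completely; that is inferred from the 1024 vs 3072 shapes in this thread, and `flatten_ref_mask` / `add_mask_bias` are hypothetical stand-ins, not the repository's `get_ref_mask`. The masking expression is the one quoted above. An RGB mask flattens to 3 x 32 x 32 = 3072 values instead of 32 x 32 = 1024:

```python
import torch
import torch.nn.functional as F


def flatten_ref_mask(ref_mask, mask_weight, H, W):
    # Hypothetical stand-in: scale, resize (N, C, h, w) -> (N, C, H, W), then flatten everything
    ref_mask = ref_mask.float() * mask_weight
    ref_mask = F.interpolate(ref_mask, size=(H, W))  # default mode="nearest"
    return ref_mask.flatten()  # length N * C * H * W


def add_mask_bias(sim_ref, ref_mask):
    # Same expression as quoted above: positions where the mask is 0 get the most negative float
    return sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim_ref.dtype).min)


H = W = 32  # 32 x 32 = 1024 tokens, the last dim of sim_ref in the log
sim_ref = torch.zeros(8, H * W, H * W)

rgb_mask = torch.zeros(1, 3, 128, 128)
rgb_mask[..., 32:96, 32:96] = 1.0  # foreground square, replicated in all 3 channels
gray_mask = rgb_mask[:, :1, :, :]  # keep a single channel

rgb_flat = flatten_ref_mask(rgb_mask, 3.0, H, W)    # 3 * 1024 = 3072 values
gray_flat = flatten_ref_mask(gray_mask, 3.0, H, W)  # 1024 values

biased = add_mask_bias(sim_ref, gray_flat)  # broadcasts over the last dim: (8, 1024, 1024)

err = None
try:
    add_mask_bias(sim_ref, rgb_flat)
except RuntimeError as e:
    err = str(e)  # size mismatch: 1024 vs 3072 at dimension 2
print(err)
```

The 3072 in the original traceback is exactly the three RGB channels times the 1024 spatial positions, which points at the mask file rather than at the attention code.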

dingangui commented 3 months ago

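Here are two suggestions that may help you track down the problem. First, revert every change you made to the code while trying to fix it and go back to the original code. Second, run the example I provided, set a breakpoint where the error first occurs, print the shape of each variable, and compare them with your own example to see whether ref_mask has a different shape. You can also set a breakpoint where the mask is read (https://github.com/aim-uofa/FreeCustom/blob/main/freecustom_stable_diffusion.py#L51) to check whether your mask file is in a different format from mine.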

lyx-JuneSnow commented 3 months ago

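Thank you very much for the suggestions. I set a breakpoint where ref_mask is read and printed its shape, and found that the bug was indeed caused by my masks having 3 channels. Adding ref_mask = ref_mask[:, :1, :, :] after line 51 of freecustom_stable_diffusion.py solved the problem.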
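Alternatively, the data itself can be fixed by re-saving the masks as single-channel images, as the maintainer suggested. A minimal sketch using Pillow; `convert_mask_to_single_channel` and the threshold of 127 are hypothetical choices, and it assumes white-on-black masks. Writing the result as PNG avoids the gray fringe values that lossy JPEG compression would reintroduce:

```python
import numpy as np
from PIL import Image


def convert_mask_to_single_channel(src_path, dst_path, threshold=127):
    """Re-save a (possibly RGB) mask image as a single-channel 0/255 mask."""
    with Image.open(src_path) as img:
        gray = img.convert("L")  # collapse RGB(A) into one luminance channel
    arr = np.array(gray)
    # Binarize so anti-aliasing or compression noise does not leave gray values
    binary = np.where(arr > threshold, 255, 0).astype(np.uint8)
    Image.fromarray(binary).save(dst_path)  # 2-D uint8 array -> mode "L"
    return binary.shape
```

Compared with slicing `ref_mask[:, :1, :, :]` at load time, converting with `convert("L")` does not depend on which color channel the mask was painted in, though a colored (non-white) mask may need a lower `threshold`.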