instantX-research / InstantID

InstantID: Zero-shot Identity-Preserving Generation in Seconds 🔥
https://instantid.github.io/
Apache License 2.0

Mask dimension mismatch in ip_adapter/attention_processor #115

Closed: BXYMartin closed this issue 9 months ago

BXYMartin commented 9 months ago

Running on the main branch (commit 98332df), I get this error in ip_adapter/attention_processor.py, line 430:

  File "/home/ubuntu/work-dir/InstantID/./ip_adapter/attention_processor.py", line 430, in forward
    ip_hidden_states = ip_hidden_states * mask
                       ~~~~~~~~~~~~~~~~~^~~~~~
RuntimeError: The size of tensor a (6400) must match the size of tensor b (9) at non-singleton dimension 1
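
Both numbers in the error can be reproduced from the mask-resizing arithmetic alone. The following is a minimal pure-Python model, not the repository's code. It assumes a 1280x1280 region mask, 10 attention heads and a (batch, heads, seq_len, head_dim) query layout in IPAttnProcessor2_0. These values are hypothetical, chosen so that 6400 = 80 x 80 tokens, and were not read from the repository. It also assumes that F.interpolate with scale_factor sizes each spatial dimension as floor(input_size * scale_factor).

```python
import math


def mask_length(mask_h, mask_w, query_len):
    """Element count of the region mask after the processor's
    F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest')
    followed by .reshape([1, -1, 1]).

    Assumption: each spatial dim becomes floor(input_size * scale_factor).
    """
    ratio = (mask_h * mask_w / query_len) ** 0.5
    scale = 1 / ratio
    out_h = math.floor(mask_h * scale)
    out_w = math.floor(mask_w * scale)
    return out_h * out_w


# Hypothetical shapes chosen to match the traceback:
# query in IPAttnProcessor2_0 assumed to be (batch, heads, seq_len, head_dim).
batch, heads, seq_len, head_dim = 2, 10, 6400, 64
query_shape = (batch, heads, seq_len, head_dim)
mask_h = mask_w = 1280

print(mask_length(mask_h, mask_w, query_shape[1]))  # uses heads   -> 9
print(mask_length(mask_h, mask_w, query_shape[2]))  # uses seq_len -> 6400
```

With query.shape[1], the ratio is computed from the head count, so the mask shrinks to 3x3 = 9 elements. It then cannot broadcast against the 6400 image tokens in ip_hidden_states.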

Possible fix (I'm not sure whether it is correct):

diff --git a/ip_adapter/attention_processor.py b/ip_adapter/attention_processor.py
index 745143d..db1cd02 100644
--- a/ip_adapter/attention_processor.py
+++ b/ip_adapter/attention_processor.py
@@ -183,7 +183,7 @@ class IPAttnProcessor(nn.Module):
             region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
             if region_mask is not None:
                 h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
+                ratio = (h * w / query.shape[2]) ** 0.5
                 mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
             else:
                 mask = torch.ones_like(ip_hidden_states)
@@ -422,7 +422,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
             region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
             if region_mask is not None:
                 h, w = region_mask.shape[:2]
-                ratio = (h * w / query.shape[1]) ** 0.5
+                ratio = (h * w / query.shape[2]) ** 0.5
                 mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
             else:
                 mask = torch.ones_like(ip_hidden_states)
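
This shape-only sketch shows why the correct index can differ between the two processors. It assumes that IPAttnProcessor reshapes query with attn.head_to_batch_dim to (batch * heads, seq_len, head_dim), while IPAttnProcessor2_0 uses view(...).transpose(1, 2) to get (batch, heads, seq_len, head_dim), as the usual diffusers-style attention processors do. The helper names are hypothetical. Under that assumption, the sequence length sits at index 1 in the first layout and at index 2 in the second. Only the IPAttnProcessor2_0 hunk would then be needed, and query.shape[-2] would work for both layouts.

```python
def head_to_batch_dim_shape(batch, seq_len, heads, head_dim):
    """Hypothetical helper: query layout assumed for IPAttnProcessor,
    i.e. after attn.head_to_batch_dim -> (batch * heads, seq_len, head_dim)."""
    return (batch * heads, seq_len, head_dim)


def sdpa_shape(batch, seq_len, heads, head_dim):
    """Hypothetical helper: query layout assumed for IPAttnProcessor2_0,
    i.e. after query.view(batch, -1, heads, head_dim).transpose(1, 2)
    -> (batch, heads, seq_len, head_dim)."""
    return (batch, heads, seq_len, head_dim)


def query_seq_len(query_shape):
    """The token axis is second-to-last in both assumed layouts."""
    return query_shape[-2]


batch, seq_len, heads, head_dim = 2, 6400, 10, 64
q_v1 = head_to_batch_dim_shape(batch, seq_len, heads, head_dim)  # (20, 6400, 64)
q_v2 = sdpa_shape(batch, seq_len, heads, head_dim)               # (2, 10, 6400, 64)

print(q_v1[1], q_v2[1])  # 6400 10   shape[1] is only right for the first layout
print(q_v1[2], q_v2[2])  # 64 6400   shape[2] is only right for the second layout
print(query_seq_len(q_v1), query_seq_len(q_v2))  # 6400 6400
```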

ResearcherXman commented 9 months ago

Yes, it is a bug, and we are working on it now. For the moment, please use the non-2.0 processors instead: from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor

ResearcherXman commented 9 months ago

Fixed. Please let us know if you have further questions.