Stability-AI / stablediffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

attn_mask dtype error #274

Open · Ishihara-Masabumi opened this issue 1 year ago

Ishihara-Masabumi commented 1 year ago

An attn_mask dtype error occurred, as follows.

$ python3 scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt "./768-v-ema.ckpt"  --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768
/home/dl/.local/lib/python3.8/site-packages/jupyter_client/__init__.py:23: UserWarning: Could not import submodules
  warnings.warn("Could not import submodules")
Global seed set to 42
Loading model from ./768-v-ema.ckpt
Global Step: 140000
LatentDiffusion: Running in v-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Downloading: "https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth" to /home/dl/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
100%|███████████████████████████████████████████████████████████████████| 5.10M/5.10M [00:00<00:00, 56.2MB/s]
Downloading (…)ip_pytorch_model.bin: 100%|██████████████████████████████| 3.94G/3.94G [01:02<00:00, 63.1MB/s]
Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...
data:   0%|                                                                            | 0/1 [00:00<?, ?it/s]
Sampling:   0%|                                                                        | 0/3 [00:00<?, ?it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/dl/StableDiffusion0/stablediffusion/scripts/txt2img.py:388 in <module>                     │
│                                                                                                  │
│   385                                                                                            │
│   386 if __name__ == "__main__":                                                                 │
│   387 │   opt = parse_args()                                                                     │
│ ❱ 388 │   main(opt)                                                                              │
│   389                                                                                            │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/scripts/txt2img.py:342 in main                         │
│                                                                                                  │
│   339 │   │   │   │   for prompts in tqdm(data, desc="data"):                                    │
│   340 │   │   │   │   │   uc = None                                                              │
│   341 │   │   │   │   │   if opt.scale != 1.0:                                                   │
│ ❱ 342 │   │   │   │   │   │   uc = model.get_learned_conditioning(batch_size * [""])             │
│   343 │   │   │   │   │   if isinstance(prompts, tuple):                                         │
│   344 │   │   │   │   │   │   prompts = list(prompts)                                            │
│   345 │   │   │   │   │   c = model.get_learned_conditioning(prompts)                            │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/ldm/models/diffusion/ddpm.py:665 in                    │
│ get_learned_conditioning                                                                         │
│                                                                                                  │
│    662 │   def get_learned_conditioning(self, c):                                                │
│    663 │   │   if self.cond_stage_forward is None:                                               │
│    664 │   │   │   if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_mod  │
│ ❱  665 │   │   │   │   c = self.cond_stage_model.encode(c)                                       │
│    666 │   │   │   │   if isinstance(c, DiagonalGaussianDistribution):                           │
│    667 │   │   │   │   │   c = c.mode()                                                          │
│    668 │   │   │   else:                                                                         │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/ldm/modules/encoders/modules.py:236 in encode          │
│                                                                                                  │
│   233 │   │   return x                                                                           │
│   234 │                                                                                          │
│   235 │   def encode(self, text):                                                                │
│ ❱ 236 │   │   return self(text)                                                                  │
│   237                                                                                            │
│   238                                                                                            │
│   239 class FrozenOpenCLIPImageEmbedder(AbstractEncoder):                                        │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1501 in _call_impl        │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/ldm/modules/encoders/modules.py:213 in forward         │
│                                                                                                  │
│   210 │                                                                                          │
│   211 │   def forward(self, text):                                                               │
│   212 │   │   tokens = open_clip.tokenize(text)                                                  │
│ ❱ 213 │   │   z = self.encode_with_transformer(tokens.to(self.device))                           │
│   214 │   │   return z                                                                           │
│   215 │                                                                                          │
│   216 │   def encode_with_transformer(self, text):                                               │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/ldm/modules/encoders/modules.py:220 in                 │
│ encode_with_transformer                                                                          │
│                                                                                                  │
│   217 │   │   x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]               │
│   218 │   │   x = x + self.model.positional_embedding                                            │
│   219 │   │   x = x.permute(1, 0, 2)  # NLD -> LND                                               │
│ ❱ 220 │   │   x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)               │
│   221 │   │   x = x.permute(1, 0, 2)  # LND -> NLD                                               │
│   222 │   │   x = self.model.ln_final(x)                                                         │
│   223 │   │   return x                                                                           │
│                                                                                                  │
│ /home/dl/StableDiffusion0/stablediffusion/ldm/modules/encoders/modules.py:232 in                 │
│ text_transformer_forward                                                                         │
│                                                                                                  │
│   229 │   │   │   if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(   │
│   230 │   │   │   │   x = checkpoint(r, x, attn_mask)                                            │
│   231 │   │   │   else:                                                                          │
│ ❱ 232 │   │   │   │   x = r(x, attn_mask=attn_mask)                                              │
│   233 │   │   return x                                                                           │
│   234 │                                                                                          │
│   235 │   def encode(self, text):                                                                │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1501 in _call_impl        │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/open_clip/transformer.py:154 in forward              │
│                                                                                                  │
│   151 │   │   return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]              │
│   152 │                                                                                          │
│   153 │   def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):          │
│ ❱ 154 │   │   x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))               │
│   155 │   │   x = x + self.ls_2(self.mlp(self.ln_2(x)))                                          │
│   156 │   │   return x                                                                           │
│   157                                                                                            │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/open_clip/transformer.py:151 in attention            │
│                                                                                                  │
│   148 │                                                                                          │
│   149 │   def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):        │
│   150 │   │   attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None               │
│ ❱ 151 │   │   return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]              │
│   152 │                                                                                          │
│   153 │   def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):          │
│   154 │   │   x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))               │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1501 in _call_impl        │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py:1205 in forward       │
│                                                                                                  │
│   1202 │   │   │   │   average_attn_weights=average_attn_weights,                                │
│   1203 │   │   │   │   is_causal=is_causal)                                                      │
│   1204 │   │   else:                                                                             │
│ ❱ 1205 │   │   │   attn_output, attn_output_weights = F.multi_head_attention_forward(            │
│   1206 │   │   │   │   query, key, value, self.embed_dim, self.num_heads,                        │
│   1207 │   │   │   │   self.in_proj_weight, self.in_proj_bias,                                   │
│   1208 │   │   │   │   self.bias_k, self.bias_v, self.add_zero_attn,                             │
│                                                                                                  │
│ /home/dl/.local/lib/python3.8/site-packages/torch/nn/functional.py:5373 in                       │
│ multi_head_attention_forward                                                                     │
│                                                                                                  │
│   5370 │   │   k = k.view(bsz, num_heads, src_len, head_dim)                                     │
│   5371 │   │   v = v.view(bsz, num_heads, src_len, head_dim)                                     │
│   5372 │   │                                                                                     │
│ ❱ 5373 │   │   attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_cau  │
│   5374 │   │   attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, e  │
│   5375 │   │                                                                                     │
│   5376 │   │   attn_output = linear(attn_output, out_proj_weight, out_proj_bias)                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and
query.dtype: c10::BFloat16 instead.
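
The mismatch appears to come from the CPU autocast path: under `torch.autocast`, the query inside `nn.MultiheadAttention` ends up in bfloat16 while the additive causal mask that open_clip passes in stays float32. Below is a minimal monkey-patch sketch, assuming `open_clip.transformer.ResidualAttentionBlock.attention` has the signature shown in the traceback and that the PyTorch 2.0 autocast helpers are available; this is an illustrative workaround, not a confirmed upstream fix.

```python
# Hypothetical workaround sketch (not an upstream fix): align the dtype of the
# additive attention mask with the dtype autocast will use for the query.
# Assumes open_clip.transformer.ResidualAttentionBlock.attention matches the
# traceback above and PyTorch >= 2.0.
from typing import Optional

import torch
import open_clip.transformer as oct


def _patched_attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        if torch.is_autocast_cpu_enabled():
            # CPU autocast defaults to bfloat16, which is what the query becomes.
            attn_mask = attn_mask.to(torch.get_autocast_cpu_dtype())
        elif torch.is_autocast_enabled():
            attn_mask = attn_mask.to(torch.get_autocast_gpu_dtype())
        else:
            attn_mask = attn_mask.to(x.dtype)
    # Same call as the original open_clip attention body.
    return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]


oct.ResidualAttentionBlock.attention = _patched_attention
```

Applying this before the model is built (e.g., near the top of scripts/txt2img.py) leaves the rest of the pipeline untouched; the simpler route, as suggested below, is to run on CUDA so the bfloat16 CPU autocast path is never hit.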
changeonPark commented 1 year ago

Me too. How can this be solved? Please help.

changeonPark commented 1 year ago

@Ishihara-Masabumi See https://github.com/Stability-AI/stablediffusion/issues/203#issuecomment-1492968988

97XingChen commented 1 year ago

Add `--device cuda` to the command.
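
For reference, the full invocation with the device flag appended (assuming your checkout's scripts/txt2img.py exposes a --device option, as the current upstream script does):

$ python3 scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt "./768-v-ema.ckpt" --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768 --device cuda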