kousw / stable-diffusion-webui-daam

DAAM for Stable Diffusion Web UI

Temporary solution for the issues of not generating heatmaps #36

Open · alchemine opened this issue 2 months ago


Add "BREAK BREAK BREAK BREAK" to the prompt!

[Images: grid-0005 and grid_daam-0003]


Description

Even though I filled in "Attention texts for visualization. (comma separated)", I ran into a problem where no heatmap was shown.

First, the outermost code causing the problem is at https://github.com/kousw/stable-diffusion-webui-daam/blob/b23fb574bf691f0bdf503e5617a0b3578160c7a1/scripts/daam/trace.py#L363, specifically its third condition:

attn_slice.shape[-1] == hk_self.context_size

Both values should represent the sequence length (the number of input tokens).

However, when the number of input tokens is below a certain value (385, which differs from the token count shown in the WebUI), the left-hand side stays fixed at 385, so the condition is never met. As a result, an empty list is created, as reported in https://github.com/kousw/stable-diffusion-webui-daam/issues/33.
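To make the failure mode concrete, here is a minimal Python sketch of the check. It assumes the attention slice is the usual cross-attention score tensor of shape (batch × heads, latent pixels, context tokens), so its last dimension is the context length. The helper names, the expected size of 77 and the other numbers are illustrative assumptions, not values taken from the extension.

```python
# Hypothetical, simplified model of the length check (not the extension's actual code).
from typing import List, Tuple

Shape = Tuple[int, int, int]


def attention_score_shape(batch_heads: int, latent_h: int, latent_w: int, context_len: int) -> Shape:
    """Assumed shape of softmax(Q @ K^T) in one cross-attention layer:
    (batch * heads, number of latent pixels, number of context tokens)."""
    return (batch_heads, latent_h * latent_w, context_len)


def keep_matching_slices(slice_shapes: List[Shape], context_size: int) -> List[Shape]:
    """Mimic `attn_slice.shape[-1] == hk_self.context_size`:
    only slices whose last dimension equals the expected token count survive."""
    return [shape for shape in slice_shapes if shape[-1] == context_size]


# Assumed example values: 16 = 2 (cond/uncond) x 8 heads, a 64x64 latent,
# and a context that arrives padded to 385 tokens.
shapes = [attention_score_shape(16, 64, 64, 385) for _ in range(3)]

print(keep_matching_slices(shapes, context_size=77))        # [] -> nothing to build a heatmap from
print(len(keep_matching_slices(shapes, context_size=385)))  # 3
```

Once the expected size matches the padded context length, every slice passes the filter, which is the situation the BREAK workaround below tries to reach.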

The problematic value of 385 comes from kwargs["context"] at https://github.com/kousw/stable-diffusion-webui-daam/blob/b23fb574bf691f0bdf503e5617a0b3578160c7a1/scripts/daam/trace.py#L40. Since this value is not generated by the DAAM extension but is passed in from stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py (see the call stack below), a proper fix seems to require understanding the underlying logic there.

Traceback (most recent call last):
  File "/workspace/generative/stable-diffusion-webui/modules/call_queue.py", line 74, in f
    res = list(func(*args, **kwargs))
  File "/workspace/generative/stable-diffusion-webui/modules/call_queue.py", line 53, in f
    res = func(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/call_queue.py", line 37, in f
    res = func(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/txt2img.py", line 109, in txt2img
    processed = processing.process_images(p)
  File "/workspace/generative/stable-diffusion-webui/modules/processing.py", line 847, in process_images
    res = process_images_inner(p)
  File "/workspace/generative/stable-diffusion-webui/modules/processing.py", line 988, in process_images_inner
    samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
  File "/workspace/generative/stable-diffusion-webui/modules/processing.py", line 1346, in sample
    samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 159, in sample
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_common.py", line 272, in launch_sampling
    return func()
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 159, in <lambda>
    samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_timesteps_impl.py", line 64, in ddim_cfgpp
    e_t = model(x, timesteps[index].item() * s_in, **extra_args)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_cfg_denoiser.py", line 249, in forward
    x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 31, in forward
    return self.inner_model.apply_model(input, timesteps, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_hijack_utils.py", line 22, in <lambda>
    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  File "/workspace/generative/stable-diffusion-webui/modules/sd_hijack_utils.py", line 34, in __call__
    return self.__sub_func(self.__orig_func, *args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_hijack_unet.py", line 50, in apply_model
    result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/modules/sd_hijack_utils.py", line 22, in <lambda>
    setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  File "/workspace/generative/stable-diffusion-webui/modules/sd_hijack_utils.py", line 36, in __call__
    return self.__orig_func(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py", line 858, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py", line 1335, in forward
    out = self.diffusion_model(x, t, context=cc)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/generative/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)

Therefore, as a temporary workaround, I used the BREAK keyword to make the input sequence long enough for the condition to hold. Of course, if the prompt is already long enough, there's no need to use BREAK.
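The arithmetic behind the workaround can be sketched as follows. This assumes the WebUI splits a prompt into 75-token chunks, wraps each chunk with a start and an end token (77 positions per chunk), and that BREAK closes the current chunk and starts a new one. Under these assumptions, 385 corresponds to exactly five chunks; how many BREAKs a particular prompt needs depends on its length and on where the BREAKs are placed.

```python
import math

# Assumption: the WebUI encodes the prompt in chunks of 75 tokens, and a start
# and an end token are added to each chunk, so every chunk contributes 77 positions.
TOKENS_PER_CHUNK = 75
CHUNK_LENGTH = TOKENS_PER_CHUNK + 2


def encoded_length(num_chunks: int) -> int:
    """Length of the conditioning context for a given number of prompt chunks."""
    return num_chunks * CHUNK_LENGTH


def chunks_needed(target_length: int) -> int:
    """Smallest number of chunks whose encoded length reaches target_length."""
    return math.ceil(target_length / CHUNK_LENGTH)


print(encoded_length(1))   # 77: a short prompt that fits in one chunk
print(chunks_needed(385))  # 5: 385 == 5 * 77
```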

Since I'm not an expert in Stable Diffusion or attention mechanisms, parts of this explanation may be wrong. Apologies in advance. 😂