castorini / daam

Diffusion attentive attribution maps for interpreting Stable Diffusion.
MIT License

Support for Prompt Embeddings Input Argument during Inference #48

Open chrisprasanna opened 1 year ago

chrisprasanna commented 1 year ago

This is more of a request, but would you be able to support custom embeddings and negative embeddings as pipeline arguments? The reason I want this is so I can use prompt engineering techniques such as prompt weighting/emphasis, which aren't directly supported in diffusers. Hugging Face suggests using the compel library to generate your own text embeddings and then passing those into the pipeline - https://huggingface.co/docs/diffusers/using-diffusers/weighted_prompts

However, when trying to use this with DAAM, I get the following error:

import torch

from compel import Compel
from daam import trace
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

compel = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder, truncate_long_prompts=False)

prompt = "a person"
negative_prompt = "bad art" 

conditioning = compel.build_conditioning_tensor(prompt)
negative_conditioning = compel.build_conditioning_tensor(negative_prompt)

[embeddings, negative_embeddings] = compel.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])

with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
    with trace(pipeline) as tc:
        image = pipeline(
            prompt_embeds=embeddings, 
            negative_prompt_embeds=negative_embeddings, 
            height=512,
            width=512,
            num_images_per_prompt=1,
            num_inference_steps=35,
            guidance_scale=7.5,  
        ).images[0]

ERROR MESSAGE:

╭─────────────────────────────── Traceback (most recent call last) ───────────────────────────────╮
│ in <module>:12                                                                                   │
│                                                                                                  │
│    9                                                                                             │
│   10 with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():                          │
│   11 │   with trace(pipeline) as tc:                                                             │
│ ❱ 12 │   │   image = pipeline(                                                                   │
│   13 │   │   │   prompt_embeds=embeddings,                                                       │
│   14 │   │   │   height=height,                                                                  │
│   15 │   │   │   width=width,                                                                    │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in           │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeli │
│ ne_stable_diffusion.py:645 in __call__                                                           │
│                                                                                                  │
│   642 │   │   do_classifier_free_guidance = guidance_scale > 1.0                                 │
│   643 │   │                                                                                      │
│   644 │   │   # 3. Encode input prompt                                                           │
│ ❱ 645 │   │   prompt_embeds = self._encode_prompt(                                               │
│   646 │   │   │   prompt,                                                                        │
│   647 │   │   │   device,                                                                        │
│   648 │   │   │   num_images_per_prompt,                                                         │
│                                                                                                  │
│ /opt/conda/envs/pytorch/lib/python3.10/site-packages/daam/trace.py:146 in _hooked_encode_prompt  │
│                                                                                                  │
│   143 │   │   return image, has_nsfw                                                             │
│   144 │                                                                                          │
│   145 │   def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, Li   │
│ ❱ 146 │   │   if not isinstance(prompt, str) and len(prompt) > 1:                                │
│   147 │   │   │   raise ValueError('Only single prompt generation is supported for heat map co   │
│   148 │   │   elif not isinstance(prompt, str):                                                  │
│   149 │   │   │   last_prompt = prompt[0]                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: object of type 'NoneType' has no len()

Compel documentation - https://github.com/damian0815/compel
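For what it's worth, the failure appears to come from DAAM's hooked `_encode_prompt` (`daam/trace.py:146` in the traceback): when only `prompt_embeds` is passed, diffusers calls `_encode_prompt` with `prompt=None`, and the hook's length check then fails. A minimal sketch of the failing check, reconstructed from the truncated traceback lines (the exact error string is an assumption):

# Reconstructed from daam/trace.py:146 in the traceback above; the error string is truncated there.
prompt = None  # diffusers passes prompt=None to _encode_prompt when only prompt_embeds is supplied

if not isinstance(prompt, str) and len(prompt) > 1:  # TypeError: object of type 'NoneType' has no len()
    raise ValueError('Only single prompt generation is supported for heat map computation.')
elif not isinstance(prompt, str):
    last_prompt = prompt[0]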

daemon commented 1 year ago

I'll look into it. Overall, DAAM's monkey-patching approach does make for brittleness.
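For readers unfamiliar with the approach: DAAM swaps out pipeline methods such as `_encode_prompt` for wrappers at runtime, so a new keyword argument like `prompt_embeds` can slip past the wrapper's assumptions. A generic sketch of that pattern, purely illustrative and not DAAM's actual code:

original_encode_prompt = pipeline._encode_prompt  # keep the bound original method

def hooked_encode_prompt(prompt, *args, **kwargs):
    # A wrapper like this implicitly assumes `prompt` is a string or a non-empty list,
    # which breaks once callers switch to prompt_embeds and leave prompt=None.
    print(f"tracing prompt: {prompt!r}")
    return original_encode_prompt(prompt, *args, **kwargs)

pipeline._encode_prompt = hooked_encode_prompt  # the instance attribute shadows the class method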

shahariar-shibli commented 1 year ago

Hi, thanks for your hard work. It would be of great help if custom embeddings were supported as pipeline arguments.

shahariar-shibli commented 1 year ago

@chrisprasanna Hi, have you found any workaround regarding this?

s183898 commented 5 months ago

Also curious about this.

Although, after some thought, this might not be possible at all, since it would require recovering the text/prompt from the text embeddings.
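To make that concrete: DAAM needs the prompt's tokens to label each heat map with a word, but compel hands the pipeline only per-token embeddings, from which the token ids are no longer recoverable. A rough sketch, assuming Stable Diffusion v1.5's CLIP text encoder (77 token positions, 768-dimensional hidden states):

# Rough sketch (assumes SD v1.5's CLIP text encoder): compel returns a
# batch x tokens x hidden tensor, with no token ids left to map heat maps back to words.
conditioning = compel.build_conditioning_tensor("a person")
print(conditioning.shape)  # approximately torch.Size([1, 77, 768])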