Nota-NetsPresso / BK-SDM

A Compressed Stable Diffusion for Efficient Text-to-Image Generation [ECCV'24]

Could the authors share the code for producing the heat map in Figure 8? I really appreciate your nice work and kind help. #55

StormArcher closed this issue 4 months ago

StormArcher commented 6 months ago

Could the authors share the code for producing the heat map in Figure 8? I really appreciate your nice work and kind help.

bokyeong1015 commented 4 months ago

Hi,

Our attention map analysis was conducted by @deepkyu (big thanks), and the code is based on the daam package:

(1) Install the package: pip install daam

(2) Run the example code:

  import torch
  from daam import set_seed, trace
  from diffusers import StableDiffusionPipeline
  from matplotlib import pyplot as plt

  pipe = StableDiffusionPipeline.from_pretrained(
      "nota-ai/bk-sdm-small", torch_dtype=torch.float16
  )
  pipe = pipe.to("cuda")

  prompt = "a golden vase with different flowers"
  gen = set_seed(0)  # for reproducibility

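  word_heat_map = "vase"  # prompt word whose attention map is visualized
  # torch.autocast replaces the deprecated torch.cuda.amp.autocast
  with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
      with trace(pipe) as tc:  # record attention while the image is generated
          image = pipe(prompt, generator=gen)
          # Combine the recorded attention into one global heat map
          heat_map = tc.compute_global_heat_map()
          # Keep only the map for the chosen word
          heat_map = heat_map.compute_word_heat_map(word_heat_map)
          # Draw the heat map over the generated image and save the figure
          heat_map.plot_overlay(image.images[0])
          plt.savefig(f"bk-sdm-small_{word_heat_map}.png")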
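
If heat maps for several prompt words are needed, the same calls can be wrapped in a small loop. The following is only a sketch and not part of the authors' code: the helper name save_word_heat_maps and its save_fig and prefix parameters are hypothetical, and it uses only the compute_word_heat_map and plot_overlay calls shown above. It assumes every word appears in the prompt.

```python
def save_word_heat_maps(global_heat_map, image, words, save_fig,
                        prefix="bk-sdm-small"):
    """Overlay and save one heat map per prompt word (hypothetical helper).

    global_heat_map: result of tc.compute_global_heat_map()
    image:           generated image, e.g. pipe(...).images[0]
    words:           words taken from the prompt
    save_fig:        callable that writes the current figure to a path
    """
    paths = []
    for word in words:
        # Select the heat map for this word from the global map
        word_map = global_heat_map.compute_word_heat_map(word)
        # Draw the word's heat map on top of the generated image
        word_map.plot_overlay(image)
        # Same naming scheme as the single-word example
        path = f"{prefix}_{word}.png"
        save_fig(path)
        paths.append(path)
    return paths
```

Inside the `with trace(pipe) as tc:` block, this could be called as `save_word_heat_maps(tc.compute_global_heat_map(), image.images[0], ["vase", "flowers"], save_fig)`, where `save_fig` is a small function that calls `plt.savefig(path)` followed by `plt.close()` so that each word gets a fresh figure.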