google / prompt-to-prompt


visualizing self attention map #85

Open dain5832 opened 6 months ago

dain5832 commented 6 months ago

Hi, I wonder how these figures were obtained. Running SVD on a self-attention map would produce U, S, and V^T. How did you obtain the figures from those?

[attached figure: capture]

ChiehYunChen commented 6 months ago

Hi @dain5832, I am not the author, but I think they visualize the self-attention maps with the code below, which I found in their released code:

```python
import numpy as np
from typing import List
from PIL import Image

# AttentionStore, aggregate_attention, and ptp_utils all come from the
# prompt-to-prompt repository itself.
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    # Average the stored self-attention maps into one (res**2, res**2) matrix.
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select)
    attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
    # Mean-center the rows and take the SVD; the rows of vh are the
    # principal components over the res * res spatial positions.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        # Reshape each component back into a res x res spatial map.
        image = vh[i].reshape(res, res)
        # Normalize to [0, 255] for display.
        image = image - image.min()
        image = 255 * image / image.max()
        # Grayscale -> 3-channel uint8, upsampled to 256 x 256.
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    # Show the top components side by side.
    ptp_utils.view_images(np.concatenate(images, axis=1))
```
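In the notebooks this is presumably called with an AttentionStore collected during sampling, something like `show_self_attention_comp(controller, res=16, from_where=["up", "down"])` (the exact arguments depend on the notebook). For reference, here is a minimal self-contained sketch of the same idea on a synthetic map; the `res` value, the random matrix, and the component count are placeholders for illustration, and only numpy and PIL are needed:

```python
import numpy as np
from PIL import Image

res = 16  # placeholder resolution; the repo code uses the latent resolution
rng = np.random.default_rng(0)

# Stand-in for an aggregated (res**2, res**2) self-attention map:
# each row is a softmax-like distribution over all spatial positions.
attn = rng.random((res ** 2, res ** 2))
attn = attn / attn.sum(axis=1, keepdims=True)

# Mean-center the rows and decompose; this is PCA up to scaling,
# and each row of vh is one principal component of the attention.
u, s, vh = np.linalg.svd(attn - attn.mean(axis=1, keepdims=True))

panels = []
for i in range(5):  # top 5 components
    comp = vh[i].reshape(res, res)  # component as a spatial map
    comp = 255 * (comp - comp.min()) / (comp.max() - comp.min())
    panels.append(np.array(Image.fromarray(comp.astype(np.uint8)).resize((256, 256))))

grid = np.concatenate(panels, axis=1)  # one row of component heatmaps
Image.fromarray(grid).save("self_attention_components.png")
```

The mean-centering before the SVD is what makes the components behave like PCA directions: each reshaped row of vh highlights a coherent spatial structure in the attention, which seems to be how the figures in the question were produced.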