alibaba / EasyNLP

EasyNLP: A Comprehensive and Easy-to-use NLP Toolkit
Apache License 2.0

Visualizing self attention #355

Closed dain5832 closed 5 months ago

dain5832 commented 5 months ago

Hi, I was trying to draw a self-attention map just like Fig. 3 in the paper, and came up with a few questions. Say I have a 1024 x 1024 attention map; SVD will return u, s, v.T. What does "top-6 components after SVD" mean? Does the top-0 map correspond to u @ s[:, 0] @ v.T[0, :]?
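(For reference, the usual reading of "top-k components" is in terms of rank-1 SVD terms, one per singular value, rather than u @ s[:, 0] @ v.T[0, :]. A minimal numpy sketch of that interpretation, with a random matrix standing in for the real attention map:)

```python
import numpy as np

# Toy stand-in for a 1024 x 1024 self-attention map (hypothetical data).
attn = np.random.rand(1024, 1024)

u, s, vh = np.linalg.svd(attn)

# The i-th SVD "component" is the rank-1 matrix s[i] * outer(u[:, i], vh[i]);
# the "top-6 components" are the six terms with the largest singular values.
components = [s[i] * np.outer(u[:, i], vh[i]) for i in range(6)]

# Summing all such terms (equivalently (u * s) @ vh) recovers the original map.
print(np.allclose((u * s) @ vh, attn))  # True
```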


Bingyan-Liu commented 5 months ago

Use the following self-attention visualization code from prompt-to-prompt and adapt it to your existing code:

```python
import numpy as np
from PIL import Image
from typing import List

# AttentionStore, aggregate_attention and ptp_utils are defined in the prompt-to-prompt notebook.
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    # SVD of the mean-centered attention map; rows of vh are the principal components.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    ptp_utils.view_images(np.concatenate(images, axis=1))
```

from: https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb
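If you already have the attention map as a plain array (e.g. the 1024 x 1024 map mentioned above, i.e. res = 32), the AttentionStore / aggregate_attention plumbing can be dropped. A self-contained numpy/PIL adaptation of the same idea (the function name and arguments here are illustrative, not an EasyNLP or prompt-to-prompt API):

```python
import numpy as np
from PIL import Image

def show_self_attention_comp_from_map(attention_map: np.ndarray, res: int, max_com: int = 6) -> Image.Image:
    """Visualize the top SVD components of a (res*res, res*res) self-attention map."""
    centered = attention_map - np.mean(attention_map, axis=1, keepdims=True)
    u, s, vh = np.linalg.svd(centered)
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)      # i-th right singular vector as a res x res map
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        images.append(np.array(Image.fromarray(image).resize((256, 256))))
    return Image.fromarray(np.concatenate(images, axis=1))

# Example: a random stand-in for a 1024 x 1024 attention map (res = 32).
attn = np.random.rand(1024, 1024).astype(np.float32)
show_self_attention_comp_from_map(attn, res=32).save("self_attention_components.png")
```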