XLearning-SCU / 2024-ICLR-READ

PyTorch implementation of "Test-time Adaption against Multi-modal Reliability Bias".

Could you provide the code for visualizing the attention map mentioned in the paper? #7

Open heyuanpengpku opened 1 week ago

heyuanpengpku commented 1 week ago

Great job! I wonder if you could kindly open-source the code for visualizing the attention map? Looking forward to your response! Thanks so much!

chenmc1996 commented 1 week ago

The model.module.forward_eval(a=x[0], v=x[1], mode=args.testmode) line in the TTA/READ.py file already returns the attention matrix as its second return value. You just need to plot that one, I guess. I'm not the author, by the way :)

heyuanpengpku commented 1 week ago

> the model.module.forward_eval(a=x[0], v=x[1], mode=args.testmode) line in the TTA/READ.py file already return the attention matrix in the second variable. You just need to draw that one I guess. I'm not the author by the way :)

I'll give it a try! Thank you for the reminder~ Looking forward to the author's response as well.

mouxingyang commented 1 day ago

Sorry for the delayed response.

As noted by chenmc1996, you need to gather all the returned attention matrices and average them. For visualization, the following code is provided for reference.
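The accumulation step itself is not shown in the repository snippet, so here is a minimal sketch of gathering the per-batch attention matrices and averaging them over the test set. The loop structure, the fake_eval stand-in, and the saved dictionary keys ("acc", "attn_matrix_mean") are illustrative assumptions chosen to match the loading code, not the authors' exact implementation:

```python
import torch

def accumulate_attention(eval_fn, loader):
    """Average the attention matrices returned (as the second value)
    by the eval function over all batches in the loader."""
    attn_sum, n = None, 0
    for x, _ in loader:
        with torch.no_grad():
            _, attn = eval_fn(x)  # second return value is the attention matrix
        attn = attn.detach().cpu()
        attn_sum = attn if attn_sum is None else attn_sum + attn
        n += 1
    return attn_sum / n

# Stand-in eval function so the sketch is self-contained; in practice this
# would wrap model.module.forward_eval(a=x[0], v=x[1], mode=args.testmode).
def fake_eval(x):
    return x.mean(), torch.softmax(torch.randn(8, 8), dim=-1)

loader = [(torch.randn(4), 0) for _ in range(5)]
mean_attn = accumulate_attention(fake_eval, loader)

# Save in the format expected by the plotting script (keys are assumptions).
torch.save({"acc": 0.0, "attn_matrix_mean": mean_attn}, "attn_matrix.pth")
```

In the real setup the matrix would be 1024x1024 (512 audio plus 512 visual tokens), which is what the A-A/V-A/A-V/V-V block indexing in the plotting code assumes.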

import torch
import matplotlib
import matplotlib.pyplot as plt

# Load the saved attention statistics (accuracy + mean attention matrix).
attn = torch.load(
    "xxx/attn_matrix/tent_clean.pth"
)
print(attn["acc"])
attn = attn["attn_matrix_mean"].numpy()

# Mean attention within/across the audio (first 512) and visual (last 512) tokens.
print(
    "A-A",
    attn[:512, :512].mean() * 1e4,
    "V-A",
    attn[:512, 512:].mean() * 1e4,
    "A-V",
    attn[512:, :512].mean() * 1e4,
    "V-V",
    attn[512:, 512:].mean() * 1e4,
)

# Plot the full matrix as a heatmap; the fixed Normalize range keeps the
# color scale comparable across different runs.
norm = matplotlib.colors.Normalize(vmin=0, vmax=0.025)
fig, ax = plt.subplots(figsize=(7, 7))
cax = ax.matshow(attn, cmap="hot", interpolation="nearest", norm=norm)
plt.axis("off")
fig.colorbar(cax)  # colorbar limits follow the Normalize above

plt.savefig("./figures/xxx.png")

mouxingyang commented 1 day ago

> the model.module.forward_eval(a=x[0], v=x[1], mode=args.testmode) line in the TTA/READ.py file already return the attention matrix in the second variable. You just need to draw that one I guess. I'm not the author by the way :)

Exactly! Thanks for your prompt solution.