heyuanpengpku opened this issue 1 week ago
The line
model.module.forward_eval(a=x[0], v=x[1], mode=args.testmode)
in the TTA/READ.py file already returns the attention matrix as its second return value. I guess you just need to plot that one. I'm not the author, by the way :)
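To illustrate what "the attention matrix as the second return value" means, here is a toy stand-in. `ToyAVModel` and everything inside it are hypothetical and not READ's actual model. Only the call pattern `forward_eval(a=..., v=..., mode=...)` and the idea that the second return value is the attention matrix come from this thread. The attention here is a plain softmax over the concatenated audio and video tokens. The `.module` in the real call suggests the model is wrapped in DataParallel/DistributedDataParallel, which the toy skips.

```python
import torch


class ToyAVModel(torch.nn.Module):
    """Hypothetical stand-in: one joint self-attention over audio + video tokens."""

    def __init__(self, dim: int = 16, num_classes: int = 10):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.head = torch.nn.Linear(dim, num_classes)

    def forward_eval(self, a, v, mode="multimodal"):
        # `mode` is accepted only to mirror the call pattern; it is unused here.
        x = torch.cat([a, v], dim=1)  # (B, Na + Nv, D), audio tokens first
        q, k, val = self.qkv(x).chunk(3, dim=-1)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        attn = scores.softmax(dim=-1)  # (B, N, N), each row sums to 1
        logits = self.head((attn @ val).mean(dim=1))
        return logits, attn  # the attention matrix is the second value


model = ToyAVModel()
a = torch.randn(2, 8, 16)  # 8 audio tokens per sample
v = torch.randn(2, 4, 16)  # 4 video tokens per sample
with torch.no_grad():
    logits, attn = model.forward_eval(a=a, v=v, mode="multimodal")
print(logits.shape, attn.shape)  # torch.Size([2, 10]) torch.Size([2, 12, 12])
```

Unpacking the tuple as `logits, attn = ...` is all that is needed to get at the matrix; its last two dimensions index query and key tokens.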
I'll give it a try! Thank you for the reminder~ Looking forward to the author's response as well.
Sorry for the delayed response.
As noted by chenmc1996, you need to gather all the returned attention matrices and sum them up. For visualization, the following code is provided for your reference.
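import os

import matplotlib.pyplot as plt
import torch
from matplotlib.colors import Normalize

# Load on the CPU so this also works on a machine without a GPU.
attn = torch.load("xxx/attn_matrix/tent_clean.pth", map_location="cpu")
print(attn["acc"])
attn = attn["attn_matrix_mean"].numpy()

# Mean attention in each modality block, scaled by 1e4 for readability.
# The first 512 tokens are treated as audio and the rest as video.
print(
    "A-A", attn[:512, :512].mean() * 1e4,
    "V-A", attn[:512, 512:].mean() * 1e4,
    "A-V", attn[512:, :512].mean() * 1e4,
    "V-V", attn[512:, 512:].mean() * 1e4,
)

# Fixed colour range [0, 0.025] so heatmaps from different runs share one scale.
norm = Normalize(vmin=0, vmax=0.025)
fig, ax = plt.subplots(figsize=(7, 7))
# cm.get_cmap was removed in Matplotlib 3.9; passing the name works everywhere.
img = ax.matshow(attn, cmap="hot", interpolation="nearest", norm=norm)  # alternatives: "Reds", "coolwarm"
ax.axis("off")
fig.colorbar(img, ax=ax)  # colour limits already come from `norm`
os.makedirs("./figures", exist_ok=True)
fig.savefig("./figures/xxx.png")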
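The script above starts from an already saved `attn_matrix_mean`; the gathering step (collect every returned attention matrix and sum them up) is not shown. Below is a minimal sketch of that step. `mean_attention` is a hypothetical helper, the batches are random stand-ins, the file name is arbitrary, and the `"acc"` value is a placeholder; only the two dictionary keys come from the script. It assumes each returned tensor has shape (..., N, N).

```python
import os
import tempfile

import torch


def mean_attention(attn_batches):
    """Sum attention maps over all samples, then divide by the sample count.

    Each element is assumed to have shape (..., N, N), e.g. (B, N, N) or
    (B, heads, N, N); every (N, N) slice counts once in the average.
    """
    total, count = None, 0
    for attn in attn_batches:
        flat = attn.detach().float().reshape(-1, *attn.shape[-2:])
        batch_sum = flat.sum(dim=0)
        total = batch_sum if total is None else total + batch_sum
        count += flat.shape[0]
    if total is None:
        raise ValueError("no attention matrices were given")
    return total / count


# Random stand-ins for per-batch attention: 3 batches of 4 row-stochastic 6 x 6 maps.
torch.manual_seed(0)
batches = [torch.rand(4, 6, 6).softmax(dim=-1) for _ in range(3)]
attn_mean = mean_attention(batches)
print(attn_mean.shape)  # torch.Size([6, 6])

# Store under the keys the plotting script reads ("acc" is a placeholder value).
path = os.path.join(tempfile.gettempdir(), "tent_clean.pth")
torch.save({"acc": 0.0, "attn_matrix_mean": attn_mean.cpu()}, path)
```

In a real evaluation loop, the list would instead hold the second return value of forward_eval for each test batch. Keeping a running sum rather than stacking every matrix keeps memory constant in the number of batches.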
Exactly! Thanks for your prompt solution.
Great job! Could you kindly open-source the code for visualizing the attention map? Looking forward to your response. Thanks so much!