zchoi / S2-Transformer

[IJCAI 2022] Official Pytorch code for paper “S2 Transformer for Image Captioning”
https://www.ijcai.org/proceedings/2022/0224.pdf
MIT License
80 stars 4 forks source link

how to get the attention mask? #11

Closed wlufy closed 1 year ago

wlufy commented 1 year ago

您好,我想生成与论文中类似的heat map,但是不知道如何获得 attention mask?想问一下您是怎么实现的? 谢谢您的回复~

zchoi commented 1 year ago

你可以在SP模块这行下面增加代码:

max_v, max_indice = torch.max(soft_assign,dim=1)  # (N, H * W) (N, H * W)
heatmap = max_indice * max_v

使用heatmap,结合一些可视化方法(例如平滑)去绘制,这里给了一个简单的例子:


plt.imshow(heatmap[k].view(7,7).cpu().numpy(), cmap=plt.cm.viridis) # viridis