Closed wlufy closed 1 year ago
您好,我想生成与论文中类似的heat map,但是不知道如何获得 attention mask?想问一下您是怎么实现的? 谢谢您的回复~
你可以在SP模块这行下面增加代码:
max_v, max_indice = torch.max(soft_assign,dim=1) # (N, H * W) (N, H * W) heatmap = max_indice * max_v
使用heatmap,结合一些可视化方法(例如平滑)去绘制,这里给了一个简单的例子:
plt.imshow(heatmap[k].view(7,7).cpu().numpy(), cmap=plt.cm.viridis) # viridis
您好,我想生成与论文中类似的heat map,但是不知道如何获得 attention mask?想问一下您是怎么实现的? 谢谢您的回复~