Open wzy9191 opened 8 months ago
您好, 我想这个问题可以拆分为两个部分回答:
如何从GAGA的代码中获取attention scores?
GAGA的开源代码中提供了EarlyStopper,训练过程中会自动保存最优的模型文件。若按照README.md中的命令执行GAGA,那么您可以在logs/checkpoints目录下找到“earlystop{DATETIME}.pth”模型文件。
加载上述模型,执行一次全量推理,分别统计正、负样本的average attention scores。由于gaga_pytorch源码直接调用torch内置的nn.TransformerEncoder,因此在推理的过程中,可以注册“PyTorch Hook“获得nn.TransformerEncoder.self_attn,此即绘制Figure3所需的原始数据。
保存attention scores。
B.T.W (了解hook可以忽略) 注册hook可以参考官方文档,基于gaga_pytorch源码的注册代码示例如下,具体可以看下modules/目录下的两个文件。
# 注册钩子
for mod in model.transformer_encoder.layers:
mod.self_attn.register_forward_hook(atten_visualization_hook)
如何将attention scores可视化为论文中Figure 3所呈现的形式?
基于第一点回答,获得的attention score矩阵大小为21✖21,我们只需要目标节点的attention scores(参考Figure 2中的目标节点位置),因此按照step=7取出atention score,大小为3✖21。
每一个序列的位置对应一个索引,分别统计“group、relation、hop”分组attention值。Figure 3中的数值经过了归一化、并保留两位小数。
你好,有几个问题想请教一下
1、通过hook获取self_attn,我得到的attention_scores为(batch_size,sequence_length, sequence_length),查阅发现可能应该为(batch_size, num_heads, sequence_length, sequence_length),我得到的没有num_heads,以下是我的代码,不知道是否哪里出了问题。
注册Hook以获取注意力得分 attention_scores = [] def hook_fn(module, input, output): attention_scores.append(output[1]) hook_handles = [] 遍历所有的nn.TransformerEncoderLayer并注册钩子 for layer in model.transformer_encoder.layers: hook_handles.append(layer.self_attn.register_forward_hook(hook_fn)) 执行全量推理,获取attention scores with torch.no_grad(): batch_seq, batch_labels = next(iter(data_loader)) # 获取整个数据集的样本 batch_seq = batch_seq.to(device) batch_logits, diff = model(batch_seq, batch_labels) 将 attention_scores 中的每个张量移动到 CPU 上,并转换为 NumPy 数组 attention_scores_np = [tensor.detach().cpu().numpy() for tensor in attention_scores] 使用 concatenate 函数将所有数组连接成一个大数组 attention_scores_np = np.concatenate(attention_scores_np, axis=0)
2、有两层transformer_encoder.layers,获取self_attn是只取一层,还是将两层进行平均 3、统计正、负样本的average attention scores,这里是将所有正样本、负样本进行平均操作嘛 4、统计“group、relation、hop”分组attention值,这里的统计具体是如何进行统计,可以说说嘛。 group:不同hop和不同relation下的同一个group是进行平均嘛 relation:relation下有不同group和不同hop,就是同一relation下,是进行相加还是平均 hop:同理,同一hop下是进行相加还是平均
谢谢!
Hi @wzy9191 , 不好意思最近比较忙,才想起来还没回复🥹。
注册“钩子”后,可以拿到对应MHA模块的 attn_output_weights
,这就是我们关心的 "attention weights"。torch 内置的 TransformerEncoder 使用默认 forward 参数 average_attn_weights=True
,因此你得到的 “attention_scores 为 (batch_size, sequence_length, sequence_length)”,也就是平均后的,详见下图和这里的 源码链接。
self_attn 取自所有层(共两层)。不过我也好奇,不同层是否会有差异呢,可以尝试下。
“统计正、负样本的 average attention scores,这里是将所有正样本、负样本进行平均操作嘛?“
正、负样本分别按照对应节点数平均下。
统计“group、relation、hop”分组attention值,这里的统计具体是如何进行统计,可以说说嘛。
分别统计对应位置的 "attention weights",“位置”是指下标,可以参考 Figure 2 中的目标节点位置排布。比如,使用论文中的参数,序列长度 S=21,对应 “group、relation、hop” 的分组attention值索引为:
hop-0的下标 : [0,7,14]
hop-1的下标 : [1,2,3, 8,9,10, 15,16,17]
hop-2的下标 : [4,5,6, 11,12,13, 18,19,20]
relation-0 的下标 : [0,1,2,3,4,5,6]
relation-1 的下标 : [7,8,9,10,11,12,13]
relation-2 的下标 : [14,15,16,17,18,19,20]
group-0 的下标 : [1,4,8,11,15,18]
group-1 的下标 : [2,5,9,12,16, 19]
group-2 的下标 : [0,7,14,3,6,10,13,17,20]
可以按照上面的描述尝试下,如有问题欢迎留言,空了会回。
您好,我想问一下Figure 3:Visualization of attention scores中的attention scores是如何获得的,可以详细的说一说嘛