FangShancheng / ABINet

Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Text Recognition
Other
420 stars 72 forks source link

问一下Parallel Attention和Position Attention的区别 #5

Open GarrettLee opened 3 years ago

GarrettLee commented 3 years ago

读了论文,我的理解文中的Position Attention如下:

# feature_map: (HW, channels)特征图
# w: (len_seq, channels)位置编码
q = w  # (len_seq, channels) 实际上位置编码就是一个可训练参数,基本上就是个fc层
k = unet(feature_map)  # (HW ,  channels)
v = feature_map  # (HW ,  channels)

attention_map = softmax(matmul(q, k.transpose()))   # (len_seq,  HW)
output = matmul(attention_map, v)   # (len_seq,  channels)

而Parallel Attention (2D Attentional Irregular Scene Text Recognizer)中是这样的:

# 参考https://arxiv.org/pdf/1906.05708.pdf公式7
# w1: (channels, channels)
# w2: (len_seq, channels)
q = w2  # (len_seq, channels)
k = tanh(w1,feature_map.transpose())  #  (channels, HW)
v = feature_map  # (HW, channels)

attention_map = softmax(matmul(q, k))  # (len_seq, HW)
output = matmul(attention_map, v)  # (len_seq, channels)

所以能否理解为Parallel Attention和Position Attention的区别只是在于对k的变换不一样?

FangShancheng commented 3 years ago

您理解的大体上是没问题的,论文中视觉模型用到的Position Attention是在现有的attention方法上的reformulate,尤其是强调,attention中对q, k, v的进一步抽象会影响模型的性能。 实现上,跟Parallel Attention等其他attention的区别有:

  1. q使用的是不可学习的position encoding参数,其通过fc层进行投影
  2. k使用unet进行进一步抽象,这一点是有效的。
GarrettLee commented 3 years ago

您理解的大体上是没问题的,论文中视觉模型用到的Position Attention是在现有的attention方法上的reformulate,尤其是强调,attention中对q, k, v的进一步抽象会影响模型的性能。

实现上,跟Parallel Attention等其他attention的区别有:

  1. q使用的是不可学习的position encoding参数,其通过fc层进行投影

  2. k使用unet进行进一步抽象,这一点是有效的。

我上面那段代码里面两个方法的q基本是一样的吧,还是我理解还有误吗

GarrettLee commented 3 years ago

Unet看起来确实可能会有效,我们也会试试看

FangShancheng commented 3 years ago

您理解的大体上是没问题的,论文中视觉模型用到的Position Attention是在现有的attention方法上的reformulate,尤其是强调,attention中对q, k, v的进一步抽象会影响模型的性能。 实现上,跟Parallel Attention等其他attention的区别有:

  1. q使用的是不可学习的position encoding参数,其通过fc层进行投影
  2. k使用unet进行进一步抽象,这一点是有效的。

我上面那段代码里面两个方法的q基本是一样的吧,还是我理解还有误吗

q这里,逻辑上有点不一样,准确来说,实现上应该是q = fc(position encoding),甚至是直接的q=position encoding,其中position encoding为transformer中的实现,而不是直接的w。这个操作只要是增加可解释性。效果上,以及unet效果的其他替代方案,vision model这里,我们并没有展开太多实验,主要还是针对language model展开的实验。

GarrettLee commented 3 years ago

明白了

simplify23 commented 3 years ago

attention中对q, k, v的进一步抽象会影响模型的性能。

想问一下,这句话是否有做过一些实验来验证呢,这里的进一步主要指的是哪一类型的方法,robust scanner方法上,对key 加入了bilstm+cnn的增强,看起来也是有效的。能不能进一步阐述一下

FangShancheng commented 3 years ago

attention中对q, k, v的进一步抽象会影响模型的性能。

想问一下,这句话是否有做过一些实验来验证呢,这里的进一步主要指的是哪一类型的方法,robust scanner方法上,对key 加入了bilstm+cnn的增强,看起来也是有效的。能不能进一步阐述一下

我们论文表1 ablation 关于视觉模型的实验有涉及到这个点,进一步抽象的方法就是对q,k,v都加函数抽象。ABINet中k是用的UNet做抽象的。此外,对v做抽象等于加强backbone,额外的实验没发现有多大增益,对q做增强也有一定收益。

YanShuang17 commented 2 years ago

@FangShancheng 作者你好!

很难理解https://github.com/FangShancheng/ABINet/blob/main/modules/attention.pyAttention类的计算att_weight的逻辑。

这里我和SRN中视觉部分(PVAM)中的attention过程作对比: (1) SRN-PVAM中的attention过程(伪代码,假设qkv的维度都是d_model):

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
key2att = nn.Linear(d_model, d_model)
query2att = nn.Linear(d_model, d_model)
embedding = nn.Embedding(max_seq_len, d_model)
score = nn.Linear(d_model, 1)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out
reading_order = torch.arange(max_seq_len, dtype=torch.long)
Q = embedding(reading_order)  # (max_seq_len, d_model)
K, V = encoder_out  # (batch_size, seq_len_k, d_model)

# 这里计算att_weight的过程很容易理解,和经典的attention模型比如ASTER的attention过程相同
######
att_q = key2att(Q).unsqueeze(0).unsqueeze(2)  # (1, seq_len_q, 1, d_model)
att_k = query2att(K).unsqueeze(1)  # (batch_size, 1, seq_len_k, d_model)
att_weight = score(torch.tanh(att_q + att_k)).squeeze(3)  # (batch_size, seq_len_q, seq_len_k)
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, K)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vicab_size)

(2) https://github.com/FangShancheng/ABINet/blob/main/modules/attention.pyAttention类的实现过程: (我注意到,贵课题组的VisionLAN中的attention也是这个,参考https://github.com/wangyuxin87/VisionLAN/blob/main/modules/modules.py中的PP_Layer类)

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
embedding = nn.Embedding(max_seq_len, d_model)
w0 = nn.Linear(max_seq_len, seq_len_k)
wv = nn.Linear(d_model, d_model)
we = nn.Linear(d_model, max_seq_len)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out
K, V = encoder_out  # (batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)

# 如何理解下面这段计算att_weight的代码?
#####
reading_order = embedding(reading_order)  # (seq_len_q, d_model)
reading_order = reading_order.unsqueeze(0).expand(K.size(0), -1)  # (batch_size, seq_len_q, d_model)
t = w0(reading_order.permute(0, 2, 1))  # (batch_size, d_model, seq_len_q) ==> (batch_size, d_model, seq_len_k)
t = torch.tanh(t.permute(0, 2, 1) + wv(K))  # (batch_size, seq_len_k, d_model)
att_weight = we(t)  # (batch_size, seq_len_k, d_model) ==> (batch_size, seq_len_k, seq_len_q)
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, K)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vicab_size)

麻烦解惑,谢谢!