649453932 / Chinese-Text-Classification-Pytorch

中文文本分类,TextCNN,TextRNN,FastText,TextRCNN,BiLSTM_Attention,DPCNN,Transformer,基于pytorch,开箱即用。
MIT License
5.27k stars 1.23k forks source link

Transformer是否有问题? #10

Open mt324010 opened 4 years ago

mt324010 commented 4 years ago

在multi-head attention 那儿求出attention之后的view()似乎会让顺序错误。 我自己试了一试, 假设h=3, batch =1, 句子长度4, 词向量5。

a = torch.randn(3,4,5) a tensor([[[-0.1241, 0.0364, 1.2337, -0.5907, 0.8305], [-0.0610, -0.9682, 0.7830, 1.5998, -0.6637], [ 0.1863, -1.2179, 0.0710, 0.6962, -0.0442], [ 0.0584, -0.5964, 0.8453, -1.3244, -0.0499]],

    [[ 2.7228,  0.6973, -1.2440,  1.8854,  2.3017],
     [-0.1034, -1.7281, -1.1495, -0.2478, -0.8541],
     [-0.2823, -0.3416, -1.3749,  0.2995, -0.1860],
     [-1.1601,  0.9876,  0.2881, -1.8866, -1.3901]],

    [[-1.1265,  1.2683, -0.7065,  0.0946,  0.3501],
     [-0.1266,  1.2834, -1.2694,  1.1730, -0.3443],
     [ 1.4679,  2.1238,  0.2405, -0.4388,  0.8566],
     [ 1.8933,  0.4461,  2.2419,  0.6118, -1.5001]]])

a.view(1, -1, 15) tensor([[[-0.1241, 0.0364, 1.2337, -0.5907, 0.8305, -0.0610, -0.9682, 0.7830, 1.5998, -0.6637, 0.1863, -1.2179, 0.0710, 0.6962, -0.0442], [ 0.0584, -0.5964, 0.8453, -1.3244, -0.0499, 2.7228, 0.6973, -1.2440, 1.8854, 2.3017, -0.1034, -1.7281, -1.1495, -0.2478, -0.8541], [-0.2823, -0.3416, -1.3749, 0.2995, -0.1860, -1.1601, 0.9876, 0.2881, -1.8866, -1.3901, -1.1265, 1.2683, -0.7065, 0.0946, 0.3501], [-0.1266, 1.2834, -1.2694, 1.1730, -0.3443, 1.4679, 2.1238, 0.2405, -0.4388, 0.8566, 1.8933, 0.4461, 2.2419, 0.6118, -1.5001]]]) 可以看到view只是按顺序拼接,并没有做到concat

wjczf123 commented 4 years ago

对,是这样的,应该有点问题

GDP-no-D commented 3 years ago
    Q = self.fc_Q(x) #[b,s,head * d_head]
    K = self.fc_K(x)
    V = self.fc_V(x)
    Q = Q.view(batch_size * self.num_head, -1, self.dim_head)
    K = K.view(batch_size * self.num_head, -1, self.dim_head)
    V = V.view(batch_size * self.num_head, -1, self.dim_head)

在view Q,K,V的时候就有点问题吧,应该是先把最后一个维度split开来,再交换维度,把head弄到第二维,后续再对每一个head做atttention. 这里的view我试了下顺序会乱掉啊

nutshell999 commented 3 years ago

各位是怎么修改的啊?我也在用,而且他的位置嵌入是二维的,我的输入是三维的。