Open mt324010 opened 4 years ago
对,是这样的,应该有点问题
Q = self.fc_Q(x) #[b,s,head * d_head]
K = self.fc_K(x)
V = self.fc_V(x)
Q = Q.view(batch_size * self.num_head, -1, self.dim_head)
K = K.view(batch_size * self.num_head, -1, self.dim_head)
V = V.view(batch_size * self.num_head, -1, self.dim_head)
在view Q,K,V的时候就有点问题吧,应该是先把最后一个维度split开来,再交换维度,把head弄到第二维,后续再对每一个head做atttention. 这里的view我试了下顺序会乱掉啊
各位是怎么修改的啊?我也在用,而且他的位置嵌入是二维的,我的输入是三维的。
在multi-head attention 那儿求出attention之后的view()似乎会让顺序错误。 我自己试了一试, 假设h=3, batch =1, 句子长度4, 词向量5。