ant-research / Pyraformer

Apache License 2.0
252 stars 38 forks source link

Maybe a bug in the TVM inpmelentation #15

Open jlidw opened 2 years ago

jlidw commented 2 years ago

TVM里的推导有点复杂,直接讲中文了哈。我仔细读了TVM部分的代码,发现反向传播部分有一个part好像有点问题,希望能和作者讨论一下。

https://github.com/alipay/Pyraformer/blob/84af4dbd93b7b96975b5034f0dde412005260123/pyraformer/hierarchical_mm_tvm.py#L74-L77 这个部分的实现是对应的是反向传播的计算,具体有两种情况,一种是attn=Q*K中反向传播(计算K的梯度),另一种是contex=attn*V中的反向传播(计算V的梯度);从矩阵计算的形式来说,这两者是等价的。所以这里我以第一种情况为例子:

*反向计算K的梯度时,对于每个k_i,我们要反过来找跟它‘结合’过的多个queries; 所以条件判断时用q_k_mask[i, k]是不正确的,这样的话依然找的是每个q_i对应的keys。正确的写法应该是: k_q_mask[i, k]>=0, X[l, k_q_mask[i, k], q, idx] Y[l, k_q_mask[i, k], q, j].** 对应的解释是:对于每个k_i, 反向找跟它结合过的queries,也即k_q_mask[i, k];这里的X对应的attn的梯度,Y对应的是Q;由于X的维度是[batch, seq_n, heads, max_attn], 对于X的梯度,还需要进一步找到k_i在k_q_mask[i, k]这个query这一行的索引,也即idx。

我推导了一下,idx貌似不能直接使用q_k_mask和k_q_mask直接计算出来,这时我们需要一个新的index matrix来使计算变得简单。假设这个idx matrix为M,那么它应该是:M[k_q_mask[i,j], i] = torch.where(q_k_mask[k_q_mask[i,j]]=i)。这个M可以像k_q_mask那样提前计算出来,然后用于TVM的计算。然后真正的计算就变成了: *k_q_mask[i, k]>=0, X[l, k_q_mask[i, k], q, M[k_q_mask[i,k], i]] Y[l, k_q_mask[i, k], q, j]**

另外不知道作者在完成TVM的代码部分之后,有没有比较它跑出来的结果跟naive impelmentation的差异大不大,还是只比较了efficiency。

这里的计算确实很绕,以上的推导只建立在我自已的理解上,不保证一定是对的,如果有可能的话,想跟作者仔细讨论一下这里TVM的实现。非常感谢!

Zhazhan commented 2 years ago

你好,感谢你对我们工作的关注。

你实现反向传播的思路是对的,按照你的思路写出来的代码应该也能够正常运行。只不过我们代码里"k_q_mask"的含义和你想的稍有不同,才造成了“可能是bug”的误会。

先说为什么计算K的梯度时,我们仍然使用的是q_k_mask:因为PAM的注意力机制构成的图是无向图,这意味着第i个query关注了哪些key,第i个key就被哪些query关注。因此,对于每个k_i,反过来找跟它‘结合’过的多个queries,也需要看q_k_mask[i]中的元素。

也正是因为q_k_mask兼任了你的实现中"k_q_mask"的功能,所以我们代码中的k_q_mask和你理解的稍有不同,实际上和你的实现中M的作用相近。k[i, j]的含义是:对于第i个key的第j个query,k[i ,j]存储了序列中第j个点作为query时,序列中第i个点在它关注列表中的索引。

这部分因为索引众多,逻辑确实很绕,希望上面的解释能帮助到你。

另外在完成TVM的代码后我们也实际验证过它训练出的性能,和naive implementation的差别确实是不大的。

jlidw commented 2 years ago

非常感谢!

我推导的时候,自动带入了更复杂的有向图的情况,忽略了PAM是无向图这个前提条件;有向图在实现上会更复杂一点,因为“对于每个k_i, 反向找跟它结合过的queries”的时候不能再简单地用q_k_mask来进行索引了。

我再仔细检查下自己的计算细节,再次感谢!

Zhazhan commented 2 years ago

不客气,你实现的有向图的情况可以支持更灵活的attention机制,如果有兴趣单独开一个repo的话或许能帮助到更多的人

jlidw commented 2 years ago

当然,等完善好相关实现,后续会把代码开源的~