zhouhaoyi / Informer2020

The GitHub repository for the paper "Informer" accepted by AAAI 2021.
Apache License 2.0

Hi, could one line in your code be wrong? #582

Closed · taoge666666 closed this issue 10 months ago

taoge666666 commented 10 months ago

During the ProbAttention step in the encoder on the training set, there is this line of code:

M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)

Q_K_sample has shape (32, 8, 96, 25).
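To make the shapes concrete, here is a minimal PyTorch sketch of how a Q_K_sample tensor of this shape can arise and how the line above evaluates on it. It is loosely modeled on the key-sampling step of ProbAttention, not copied from the repository, and every size and input is an assumption: B=32, H=8, L_Q = L_K = 96, sample_k = 25, random Q and K, and a small hypothetical head dimension D=8 chosen only to keep memory low.

```python
import torch

# Assumed sizes reproducing the reported shape (32, 8, 96, 25).
# D = 8 is a hypothetical head dimension, kept small to limit memory use.
B, H, L_Q, L_K, D = 32, 8, 96, 96, 8
sample_k = 25

torch.manual_seed(0)
Q = torch.randn(B, H, L_Q, D)  # random stand-in for the queries
K = torch.randn(B, H, L_K, D)  # random stand-in for the keys

# For every query position, draw sample_k key indices (with replacement).
index_sample = torch.randint(L_K, (L_Q, sample_k))

# Gather each query's sampled keys -> shape (B, H, L_Q, sample_k, D).
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, D)
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

# Dot product of each query with its own sampled keys -> (B, H, L_Q, sample_k).
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

# The line in question: max of the sampled scores minus their sum divided by L_K.
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)

print(Q_K_sample.shape, M.shape)
# torch.Size([32, 8, 96, 25]) torch.Size([32, 8, 96])
```

Note that only 25 scores per query are summed, while the divisor is L_K = 96.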

My understanding is that M is the maximum of the QK dot products over the 25 sampled keys minus the mean of those 25 QK dot products.
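The "max minus mean" quantity described here can be written directly with `.mean(-1)`. In this sketch a random tensor stands in for Q_K_sample (an assumption, not real attention scores); it also shows that averaging over the last dimension is the same as dividing its sum by sample_k.

```python
import torch

sample_k = 25
torch.manual_seed(0)
# Random stand-in for Q_K_sample with the reported shape (32, 8, 96, 25).
Q_K_sample = torch.randn(32, 8, 96, sample_k)

# Max over the 25 sampled scores minus their mean, per query.
M_max_minus_mean = Q_K_sample.max(-1)[0] - Q_K_sample.mean(-1)

# Dividing the sum over the last dimension by sample_k gives the same mean.
M_div_sample_k = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), sample_k)

print(torch.allclose(M_max_minus_mean, M_div_sample_k))  # True
```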

But in torch.div(Q_K_sample.sum(-1), L_K), L_K = 96.
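Writing the 25 sampled scores of one query as s_1, …, s_25 (notation introduced here only for illustration), the arithmetic of the current line is:

$$
M_{\text{code}} = \max_{j} s_j - \frac{1}{L_K}\sum_{j=1}^{25} s_j
= \max_{j} s_j - \frac{25}{96}\cdot\underbrace{\frac{1}{25}\sum_{j=1}^{25} s_j}_{\bar{s}}
\approx \max_{j} s_j - 0.26\,\bar{s}
$$

So with L_K = 96 the subtracted term is about 0.26 times the mean of the sampled scores rather than the mean itself.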

Shouldn't it divide by sample_k = 25 instead, so that it actually averages over the last dimension?

Should the code be changed to the following?

M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), sample_k)
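A quick way to see how much the proposed change matters is to compare both versions on the same scores. This sketch again uses a random stand-in for Q_K_sample (an assumption); it checks that the two versions differ by each query's sum times (1/25 - 1/96), a gap that depends on the query's own scores rather than being one constant offset.

```python
import torch

L_K, sample_k = 96, 25
torch.manual_seed(1)
# Random stand-in for Q_K_sample with the reported shape (32, 8, 96, 25).
Q_K_sample = torch.randn(32, 8, 96, sample_k)
row_max = Q_K_sample.max(-1)[0]
row_sum = Q_K_sample.sum(-1)

# Current line: the sum of the 25 sampled scores is divided by L_K = 96.
M_current = row_max - torch.div(row_sum, L_K)
# Proposed line: the same sum is divided by sample_k = 25 (a true mean).
M_proposed = row_max - torch.div(row_sum, sample_k)

# The gap is row_sum * (1/25 - 1/96), so it varies from query to query.
gap = M_current - M_proposed
print(torch.allclose(gap, row_sum * (1 / sample_k - 1 / L_K), atol=1e-5))  # True
```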