R0oup1iao / Traffic-Transformer

This is a PyTorch implementation of Traffic Transformer. The corresponding paper is available online at
https://ieeexplore.ieee.org/document/9520129

LearnedPositionalEncoding #4

Open 467853140 opened 1 year ago

467853140 commented 1 year ago

Regarding this part of the model:

    import torch.nn as nn

    class LearnedPositionalEncoding(nn.Embedding):
        def __init__(self, d_model, dropout=0.1, max_len=500):
            # One learnable d_model-dim vector per position, up to max_len positions
            super().__init__(max_len, d_model)
            self.dropout = nn.Dropout(p=dropout)

        def forward(self, x):
            # x: (seq_len, batch, d_model); add a position embedding to each time step
            weight = self.weight.data.unsqueeze(1)  # (max_len, 1, d_model)
            x = x + weight[:x.size(0), :]
            return self.dropout(x)

What role does this part play? The paper doesn't seem to explain it.

R0oup1iao commented 1 year ago

This uses the nn.Embedding class to implement a learnable positional encoding.
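
For reference, a minimal usage sketch, assuming the class as defined above and a (seq_len, batch, d_model) input layout; the concrete sizes here are illustrative:

    import torch

    pe = LearnedPositionalEncoding(d_model=64)  # class as defined above
    x = torch.randn(12, 8, 64)                  # (seq_len, batch, d_model)
    out = pe(x)                                 # weight[:12] broadcast over the batch, then dropout
    print(out.shape)                            # torch.Size([12, 8, 64])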

At the time, the main goal was to get something like the effect of a relative positional encoding. There are now many strong papers on relative positional encoding and all kinds of modified PE variants, so this part could probably be upgraded and replaced.
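
For context, a minimal sketch of one possible replacement: a learned relative-position bias in the style of T5/Transformer-XL, added to the attention logits rather than to the token embeddings. RelativePositionBias and its parameters are hypothetical names for this sketch, not part of this repo:

    import torch
    import torch.nn as nn

    class RelativePositionBias(nn.Module):
        def __init__(self, num_heads, max_distance=128):
            super().__init__()
            self.max_distance = max_distance
            # One learnable scalar per head for each clipped relative distance,
            # i.e. 2 * max_distance + 1 buckets covering [-max_distance, max_distance].
            self.bias = nn.Embedding(2 * max_distance + 1, num_heads)

        def forward(self, seq_len):
            pos = torch.arange(seq_len)
            # rel[i, j] = j - i, clipped and shifted into a valid embedding index
            rel = pos[None, :] - pos[:, None]
            rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
            # (seq_len, seq_len, num_heads) -> (num_heads, seq_len, seq_len)
            return self.bias(rel).permute(2, 0, 1)

The returned (num_heads, seq_len, seq_len) tensor would be added to the attention scores before the softmax.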