Closed turmony closed 3 months ago
你好,dim指attention作用的维度。 传入数据维度为 (B, T, N, H),dim=1指在T维上做(时间),dim=2指在N维上做(空间)。
请参考定义: https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L6-L13 https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L94-L98
非常感谢您的工作,我有一点不明白,时间注意力层传入的dim=1和空间注意力层传入的dim=2的作用是什么?
for attn in self.attn_layers_t: x = attn(x, dim=1) for attn in self.attn_layers_s: x = attn(x, dim=2)