XDZhelheim / STAEformer

[CIKM'23] Official code for our paper "Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting".
https://arxiv.org/abs/2308.10425

attn_layers_t(x, dim=1) vs. attn_layers_s(x, dim=2) #9

Closed · turmony closed this issue 3 months ago

turmony commented 4 months ago

Thank you very much for your work. There is one thing I don't understand: what is the purpose of dim=1 passed to the temporal attention layers and dim=2 passed to the spatial attention layers?

```python
for attn in self.attn_layers_t:
    x = attn(x, dim=1)
for attn in self.attn_layers_s:
    x = attn(x, dim=2)
```

XDZhelheim commented 4 months ago

Hi, dim is the axis the attention operates along. The input tensor has shape (B, T, N, H): dim=1 applies attention along the T axis (temporal), and dim=2 applies it along the N axis (spatial).

Please refer to the definitions: https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L6-L13 https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L94-L98
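For anyone who wants to see the pattern in isolation, below is a minimal, self-contained sketch of the idea behind the dim argument: move the target axis next to the feature dimension, run self-attention over it, then restore the original layout. This is not the repository's SelfAttentionLayer; it uses PyTorch's nn.MultiheadAttention as a stand-in, and the class name DimAttention and the example shapes are illustrative.

```python
import torch
import torch.nn as nn

class DimAttention(nn.Module):
    """Illustrative sketch: self-attention along an arbitrary axis
    of a (B, T, N, H) tensor, selected via the `dim` argument."""

    def __init__(self, model_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)

    def forward(self, x, dim):
        # Move the chosen axis to position -2 so it becomes the token axis.
        x = x.transpose(dim, -2)
        shape = x.shape                              # (..., L, H)
        x = x.reshape(-1, shape[-2], shape[-1])      # fold leading dims into the batch
        out, _ = self.attn(x, x, x)                  # self-attention over the L tokens
        return out.reshape(shape).transpose(dim, -2) # restore the original layout

B, T, N, H = 2, 12, 207, 64                          # hypothetical sizes
x = torch.randn(B, T, N, H)
layer = DimAttention(H)
print(layer(x, dim=1).shape)  # dim=1: tokens are the T time steps (temporal)
print(layer(x, dim=2).shape)  # dim=2: tokens are the N nodes (spatial)
```

With dim=1, attention mixes information across time steps for each node independently; with dim=2, it mixes across nodes at each time step. The same module serves both roles, which is why the model can reuse one attention definition for its temporal and spatial layers.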