XDZhelheim / STAEformer

[CIKM'23] Official code for our paper "Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting".
https://arxiv.org/abs/2308.10425
132 stars 16 forks source link

模型嵌入层代码在哪里呢? #6

Closed serenaand closed 6 months ago

serenaand commented 7 months ago

你好,请问模型嵌入层代码没给出来吗?还是直接将处理好的数据直接放进了那几个数据集文件中吗?可以看看具体是如何实现数据嵌入以及论文最核心部分,时空自适应嵌入的代码吗?

XDZhelheim commented 7 months ago

Embedding就在模型的代码里。。应该挺显眼啊

定义: https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L155-L168

其中165~168行是时空自适应嵌入。

使用: https://github.com/XDZhelheim/STAEformer/blob/2dfb9e35c2f04bbb7136657100ea3d8afa5fc4e5/model/STAEformer.py#L196-L224

spatial_embedding是文章里一个对比实验用的,忽略就行。

serenaand commented 7 months ago

那为什么这个空间嵌入维度为0呢? model_args: num_nodes: 207 in_steps: 12 out_steps: 12 steps_per_day: 288 input_dim: 3 output_dim: 1 input_embedding_dim: 24 tod_embedding_dim: 24 dow_embedding_dim: 24 spatial_embedding_dim: 0 adaptive_embedding_dim: 80 feed_forward_dim: 256 num_heads: 4 num_layers: 3 dropout: 0.1

XDZhelheim commented 7 months ago

因为我们的模型里本来就没有空间嵌入,是0就对了。为什么代码里写了空间嵌入,是因为我们paper的实验部分里面,有一个 时空自适应嵌入vs空间嵌入 的对比实验。

image

serenaand commented 7 months ago

好的,感谢作者,看到啦!

serenaand commented 7 months ago

再问一个问题,这个数据集的处理是如何做的呢?这个data.npz和index.npz

XDZhelheim commented 7 months ago

OKOK

数据集生成,首先得放人家BasicTS的原版代码 https://github.com/zezhishao/BasicTS/blob/master/scripts/data_preparation/METR-LA/generate_training_data.py

我是从BasicTS学来的,自己稍微改动了一些 https://github.com/XDZhelheim/Torch-MTS/blob/dev/scripts/generate_training_data.py

serenaand commented 7 months ago

好的 感谢感谢