XDZhelheim / STAEformer

[CIKM'23] Official code for our paper "Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting".
https://arxiv.org/abs/2308.10425

Why was the PeMS03 dataset dropped? #5

Closed · zclO closed this 6 months ago

zclO commented 7 months ago

The results on PeMS03 are actually noticeably better than the ones in the paper.

XDZhelheim commented 7 months ago

Because the results weren't good, and we ran out of time to keep tinkering with it while writing the paper. Are you saying you got SOTA results on 03 with our code? Could you share them?

zclO commented 7 months ago
Trainset:   x-(15711, 12, 358, 3)   y-(15711, 12, 358, 1)
Valset:     x-(5237, 12, 358, 3)    y-(5237, 12, 358, 1)
Testset:    x-(5237, 12, 358, 3)    y-(5237, 12, 358, 1)

--------- STAEformer ---------
{
    "num_nodes": 358,
    "in_steps": 12,
    "out_steps": 12,
    "train_size": 0.6,
    "val_size": 0.2,
    "time_of_day": true,
    "day_of_week": true,
    "lr": 0.001,
    "weight_decay": 0.0005,
    "milestones": [
        15,
        30,
        40
    ],
    "lr_decay_rate": 0.1,
    "batch_size": 16,
    "max_epochs": 300,
    "early_stop": 20,
    "use_cl": false,
    "cl_step_size": 2500,
    "model_args": {
        "num_nodes": 358,
        "in_steps": 12,
        "out_steps": 12,
        "steps_per_day": 288,
        "input_dim": 3,
        "output_dim": 1,
        "input_embedding_dim": 24,
        "tod_embedding_dim": 24,
        "dow_embedding_dim": 24,
        "spatial_embedding_dim": 0,
        "adaptive_embedding_dim": 80,
        "feed_forward_dim": 256,
        "num_heads": 4,
        "num_layers": 3,
        "dropout": 0.1
    }
}
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
STAEformer                               [16, 12, 358, 1]          343,680
├─Linear: 1-1                            [16, 12, 358, 24]         96
├─Embedding: 1-2                         [16, 12, 358, 24]         6,912
├─Embedding: 1-3                         [16, 12, 358, 24]         168
├─ModuleList: 1-4                        --                        --
│    └─SelfAttentionLayer: 2-1           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-1          [16, 358, 12, 152]        93,024
│    │    └─Dropout: 3-2                 [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-3               [16, 358, 12, 152]        304
│    │    └─Sequential: 3-4              [16, 358, 12, 152]        78,232
│    │    └─Dropout: 3-5                 [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-6               [16, 358, 12, 152]        304
│    └─SelfAttentionLayer: 2-2           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-7          [16, 358, 12, 152]        93,024
│    │    └─Dropout: 3-8                 [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-9               [16, 358, 12, 152]        304
│    │    └─Sequential: 3-10             [16, 358, 12, 152]        78,232
│    │    └─Dropout: 3-11                [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-12              [16, 358, 12, 152]        304
│    └─SelfAttentionLayer: 2-3           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-13         [16, 358, 12, 152]        93,024
│    │    └─Dropout: 3-14                [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-15              [16, 358, 12, 152]        304
│    │    └─Sequential: 3-16             [16, 358, 12, 152]        78,232
│    │    └─Dropout: 3-17                [16, 358, 12, 152]        --
│    │    └─LayerNorm: 3-18              [16, 358, 12, 152]        304
├─ModuleList: 1-5                        --                        --
│    └─SelfAttentionLayer: 2-4           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-19         [16, 12, 358, 152]        93,024
│    │    └─Dropout: 3-20                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-21              [16, 12, 358, 152]        304
│    │    └─Sequential: 3-22             [16, 12, 358, 152]        78,232
│    │    └─Dropout: 3-23                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-24              [16, 12, 358, 152]        304
│    └─SelfAttentionLayer: 2-5           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-25         [16, 12, 358, 152]        93,024
│    │    └─Dropout: 3-26                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-27              [16, 12, 358, 152]        304
│    │    └─Sequential: 3-28             [16, 12, 358, 152]        78,232
│    │    └─Dropout: 3-29                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-30              [16, 12, 358, 152]        304
│    └─SelfAttentionLayer: 2-6           [16, 12, 358, 152]        --
│    │    └─AttentionLayer: 3-31         [16, 12, 358, 152]        93,024
│    │    └─Dropout: 3-32                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-33              [16, 12, 358, 152]        304
│    │    └─Sequential: 3-34             [16, 12, 358, 152]        78,232
│    │    └─Dropout: 3-35                [16, 12, 358, 152]        --
│    │    └─LayerNorm: 3-36              [16, 12, 358, 152]        304
├─Linear: 1-6                            [16, 358, 12]             21,900
==========================================================================================
Total params: 1,403,940
Trainable params: 1,403,940
Non-trainable params: 0
Total mult-adds (M): 16.96
==========================================================================================
Input size (MB): 0.82
Forward/backward pass size (MB): 4395.25
Params size (MB): 4.24
Estimated Total Size (MB): 4400.32
==========================================================================================

Loss: HuberLoss

2024-03-27 16:50:23.368405 Epoch 1      Train Loss = 23.50209 Val Loss = 16.61394
2024-03-27 16:53:10.425300 Epoch 2      Train Loss = 16.73434 Val Loss = 16.90709
2024-03-27 16:55:57.261686 Epoch 3      Train Loss = 15.57980 Val Loss = 15.09482
2024-03-27 16:58:43.915353 Epoch 4      Train Loss = 14.87770 Val Loss = 14.41593
2024-03-27 17:01:30.592427 Epoch 5      Train Loss = 14.39140 Val Loss = 14.17477
2024-03-27 17:04:17.365719 Epoch 6      Train Loss = 14.03890 Val Loss = 14.76123
2024-03-27 17:07:04.823091 Epoch 7      Train Loss = 13.70546 Val Loss = 14.78409
2024-03-27 17:09:51.538088 Epoch 8      Train Loss = 13.50300 Val Loss = 13.77077
2024-03-27 17:12:38.372563 Epoch 9      Train Loss = 13.31395 Val Loss = 15.46241
2024-03-27 17:15:25.529689 Epoch 10     Train Loss = 13.22073 Val Loss = 14.16203
2024-03-27 17:18:12.193547 Epoch 11     Train Loss = 13.06612 Val Loss = 13.55851
2024-03-27 17:21:00.579256 Epoch 12     Train Loss = 12.95349 Val Loss = 13.48954
2024-03-27 17:23:47.926544 Epoch 13     Train Loss = 12.84676 Val Loss = 13.45027
2024-03-27 17:26:35.020442 Epoch 14     Train Loss = 12.77449 Val Loss = 13.33755
2024-03-27 17:29:21.465212 Epoch 15     Train Loss = 12.72734 Val Loss = 13.45078
2024-03-27 17:32:08.188061 Epoch 16     Train Loss = 12.11283 Val Loss = 13.07718
2024-03-27 17:34:55.254284 Epoch 17     Train Loss = 12.04862 Val Loss = 13.07026
2024-03-27 17:37:43.961799 Epoch 18     Train Loss = 12.02399 Val Loss = 13.09079
2024-03-27 17:40:30.422492 Epoch 19     Train Loss = 12.00090 Val Loss = 13.03485
2024-03-27 17:43:16.776650 Epoch 20     Train Loss = 11.98208 Val Loss = 13.02317
2024-03-27 17:46:03.651230 Epoch 21     Train Loss = 11.95939 Val Loss = 13.10526
2024-03-27 17:48:50.345555 Epoch 22     Train Loss = 11.94211 Val Loss = 13.07341
2024-03-27 17:51:38.024626 Epoch 23     Train Loss = 11.92362 Val Loss = 13.02579
2024-03-27 17:54:26.089619 Epoch 24     Train Loss = 11.91059 Val Loss = 13.00774
2024-03-27 17:57:13.488255 Epoch 25     Train Loss = 11.89163 Val Loss = 13.02892
2024-03-27 18:00:00.314956 Epoch 26     Train Loss = 11.87665 Val Loss = 13.01544
2024-03-27 18:02:47.460870 Epoch 27     Train Loss = 11.86163 Val Loss = 13.03814
2024-03-27 18:05:34.653359 Epoch 28     Train Loss = 11.84811 Val Loss = 13.03936
2024-03-27 18:08:21.507201 Epoch 29     Train Loss = 11.83152 Val Loss = 13.08196
2024-03-27 18:11:08.080240 Epoch 30     Train Loss = 11.81776 Val Loss = 13.01643
2024-03-27 18:13:54.850885 Epoch 31     Train Loss = 11.73884 Val Loss = 13.02487
2024-03-27 18:16:41.434426 Epoch 32     Train Loss = 11.73005 Val Loss = 13.03038
2024-03-27 18:19:27.773940 Epoch 33     Train Loss = 11.72576 Val Loss = 13.00900
2024-03-27 18:22:14.280459 Epoch 34     Train Loss = 11.72365 Val Loss = 13.01421
2024-03-27 18:25:00.732397 Epoch 35     Train Loss = 11.72125 Val Loss = 13.04520
2024-03-27 18:27:48.083849 Epoch 36     Train Loss = 11.71866 Val Loss = 13.03444
2024-03-27 18:30:34.682686 Epoch 37     Train Loss = 11.71581 Val Loss = 13.03205
2024-03-27 18:33:21.674518 Epoch 38     Train Loss = 11.71344 Val Loss = 13.04403
2024-03-27 18:36:08.148908 Epoch 39     Train Loss = 11.71141 Val Loss = 13.05678
2024-03-27 18:38:54.643207 Epoch 40     Train Loss = 11.70965 Val Loss = 13.03150
2024-03-27 18:41:41.246057 Epoch 41     Train Loss = 11.69990 Val Loss = 13.03278
2024-03-27 18:44:27.672972 Epoch 42     Train Loss = 11.69861 Val Loss = 13.03245
2024-03-27 18:47:14.541890 Epoch 43     Train Loss = 11.69850 Val Loss = 13.03871
2024-03-27 18:50:01.627914 Epoch 44     Train Loss = 11.69951 Val Loss = 13.03122
Early stopping at epoch: 44
Best at epoch 24:
Train Loss = 11.91059
Train RMSE = 20.34306, MAE = 12.25893, MAPE = 11.26089
Val Loss = 13.00774
Val RMSE = 21.89514, MAE = 13.53552, MAPE = 12.73677
Saved Model: ../saved_models/STAEformer-PEMS03-2024-03-27-16-47-26.pt
--------- Test ---------
All Steps RMSE = 26.11140, MAE = 14.99877, MAPE = 15.21329
Step 1 RMSE = 20.96294, MAE = 12.51260, MAPE = 13.17414
Step 2 RMSE = 22.63771, MAE = 13.18854, MAPE = 13.70140
Step 3 RMSE = 23.83191, MAE = 13.74374, MAPE = 14.14419
Step 4 RMSE = 24.74219, MAE = 14.18931, MAPE = 14.44330
Step 5 RMSE = 25.50163, MAE = 14.58557, MAPE = 14.79454
Step 6 RMSE = 26.13413, MAE = 14.95367, MAPE = 15.06136
Step 7 RMSE = 26.73123, MAE = 15.30649, MAPE = 15.40546
Step 8 RMSE = 27.29975, MAE = 15.66709, MAPE = 15.79337
Step 9 RMSE = 27.79695, MAE = 15.97828, MAPE = 16.07370
Step 10 RMSE = 28.29644, MAE = 16.30727, MAPE = 16.47861
Step 11 RMSE = 28.80173, MAE = 16.62232, MAPE = 16.53350
Step 12 RMSE = 29.21557, MAE = 16.93027, MAPE = 16.95568
Inference time: 16.42 s
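
A note on reading the log above: the 152 channel width in every `[16, 12, 358, 152]` summary row is just the sum of the five embedding dims in the config (24 + 24 + 24 + 0 + 80), and the sharp train-loss drops at epochs 16 and 31 line up with the MultiStepLR milestones [15, 30, 40] with decay rate 0.1. Below is a minimal sketch of the corresponding setup; the config fields are taken from the log, while the choice of Adam, the stand-in model, and the variable names are assumptions for illustration, not code from the repo.

```python
import torch

# Fields taken from the config above; everything else in this sketch
# (Adam, the stand-in model, the names) is an illustrative assumption.
lr, weight_decay = 0.001, 0.0005
milestones, lr_decay_rate = [15, 30, 40], 0.1

embedding_dims = {
    "input_embedding_dim": 24,
    "tod_embedding_dim": 24,
    "dow_embedding_dim": 24,
    "spatial_embedding_dim": 0,
    "adaptive_embedding_dim": 80,
}

# The per-token model width is the concatenation of all embeddings,
# which is the 152 appearing throughout the model summary.
model_dim = sum(embedding_dims.values())
assert model_dim == 152

model = torch.nn.Linear(3, 1)  # stand-in for STAEformer, to keep the sketch runnable
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=milestones, gamma=lr_decay_rate
)
criterion = torch.nn.HuberLoss()  # matches "Loss: HuberLoss" in the log
```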
XDZhelheim commented 7 months ago

It is indeed slightly better than what we got, but there is still a sizable gap compared with the baselines.

Suasy commented 7 months ago

I've found that the random seed has a considerable impact on the results; my runs differ noticeably from one another.

XDZhelheim commented 7 months ago

> I've found that the random seed has a considerable impact on the results; my runs differ noticeably from one another.

That's just randomness; don't worry about it too much. I often see the same thing when running other models, and some fluctuation up or down is normal. When I was cleaning up this repo I also just ran it casually, and the results came out better than in our paper...
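
For anyone who wants to pin down this run-to-run fluctuation, here is a minimal seeding sketch; the helper name and seed value are illustrative, and even full seeding does not remove all nondeterminism on GPU.

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Hypothetical helper (not from the repo): fix the common RNG sources."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for determinism; results may still differ across
    # GPU models and CUDA/cuDNN versions even with identical seeds.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(2023)
```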

zclO commented 7 months ago

For the METRLA dataset, does in_steps=12 correspond to horizon 12? This result is also much better than the one in your paper. By the way, how do I set up the horizon-3 experiment?

XDZhelheim commented 7 months ago

> For the METRLA dataset, does in_steps=12 correspond to horizon 12? This result is also much better than the one in your paper. By the way, how do I set up the horizon-3 experiment?

in_steps is the number of input steps (12 historical steps, i.e., one hour), and out_steps is the number of predicted steps (12 future steps, i.e., one hour). Beating the paper is a good thing; it means the numbers we reported were on the conservative side.

If by horizon 3 you mean using 3 input steps, the current code does not support that, because 12-in/12-out is the only widely accepted setting on these datasets.

zclO commented 7 months ago

Then how were the 3-step and 6-step results in the paper obtained?

XDZhelheim commented 7 months ago

It looks like I should explain some basic concepts of traffic forecasting...

The task on these datasets is multi-step forecasting (as opposed to single-step forecasting): given 12 input steps, the model predicts all 12 future steps at once. That is why our training log ends with 12 rows, Step 1 through Step 12, and the "All Steps" row above them is the metric computed over all 12 steps as a whole. For LA and BAY, we report steps 3, 6, and 12 individually, i.e., exactly those three rows of the log. For the PEMS datasets, we compare the average, i.e., the "All Steps" row, and you can ignore the 12 rows after it.

A counterexample: if I remember correctly, the original STGCN code does single-step forecasting, so if you want steps 3, 6, and 12, you have to run their code three separate times, and each run outputs only the single step you specified.
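
To make the convention concrete: for LA/BAY you read off the Step 3/6/12 rows, while for PEMS you use the All Steps row. Here is a minimal NumPy sketch of how those two kinds of metrics could be computed from a multi-step prediction tensor; the helpers are illustrative, not the repo's own metric code, which may additionally mask zero values.

```python
import numpy as np

def step_metrics(y_pred, y_true, step):
    """RMSE/MAE/MAPE at one horizon step (1-indexed, as in the log above).

    y_pred, y_true: arrays of shape (samples, out_steps, num_nodes, 1).
    Hypothetical helper for illustration only.
    """
    p, t = y_pred[:, step - 1], y_true[:, step - 1]
    rmse = np.sqrt(np.mean((p - t) ** 2))
    mae = np.mean(np.abs(p - t))
    mape = np.mean(np.abs((p - t) / t)) * 100  # assumes zeros already masked
    return rmse, mae, mape

def all_steps_metrics(y_pred, y_true):
    """The "All Steps" row: the same metrics over all 12 steps as a whole."""
    return step_metrics(
        y_pred.reshape(-1, 1, *y_pred.shape[2:]),
        y_true.reshape(-1, 1, *y_true.shape[2:]),
        step=1,
    )
```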

zclO commented 7 months ago

> It looks like I should explain some basic concepts of traffic forecasting...
>
> The task on these datasets is multi-step forecasting (as opposed to single-step forecasting): given 12 input steps, the model predicts all 12 future steps at once. That is why our training log ends with 12 rows, Step 1 through Step 12, and the "All Steps" row above them is the metric computed over all 12 steps as a whole. For LA and BAY, we report steps 3, 6, and 12 individually, i.e., exactly those three rows of the log. For the PEMS datasets, we compare the average, i.e., the "All Steps" row, and you can ignore the 12 rows after it.
>
> A counterexample: if I remember correctly, the original STGCN code does single-step forecasting, so if you want steps 3, 6, and 12, you have to run their code three separate times, and each run outputs only the single step you specified.

Got it. I mistakenly assumed that Horizon 12 here was also the average over 12 steps; in that case my reproduced results are basically consistent with the paper. Thanks for the explanation, and great work!