This is a PyTorch implementation of ST-SSL in the following paper:
27/10/2023: This paper was picked up by leading WeChat official accounts in the fields of data mining and transportation: 当交通遇上机器学习 | 时空实验室 | AI蜗牛车
22/04/2023: The post about this paper was selected as a headline tweet by PaperWeekly and received nearly 7,000 reads. PaperWeekly is a leading AI academic platform in China.
09/02/2023: The video replay of the academic presentation at AAAI 2023 is available.
04/02/2023: J. Ji was invited to give a talk at the AAAI 2023 Beijing Pre-Conference. The talk is about Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction.
This project is built with Python 3.8 and the following packages:
numpy==1.21.2
pandas==1.3.5
PyYAML==6.0
torch==1.10.1
The available datasets are {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}. You can download them from the GitHub repo, Beihang Cloud Drive, or Google Drive.
Each dataset is composed of 4 files, namely train.npz, val.npz, test.npz, and adj_mx.npz.
|----NYCBike1\
| |----train.npz # training data
| |----adj_mx.npz # predefined graph structure
| |----test.npz # test data
| |----val.npz # validation data
The train/val/test data is composed of 4 numpy.ndarray objects:
X: input data. It is a 4D tensor of shape (#samples, #lookback_window, #nodes, #flow_types), where # stands for "number of".
Y: data to be predicted. It is a 4D tensor of shape (#samples, #predict_horizon, #nodes, #flow_types). Note that X and Y are paired in the sample dimension. For instance, (X_i, Y_i) is the i-th data sample, with i indexing the sample dimension.
X_offset: a list indicating the offsets of X's lookback window relative to the current time, which has offset 0.
Y_offset: a list indicating the offsets of Y's prediction horizon relative to the current time, which has offset 0.
For all datasets, the previous 2-hour flows as well as the previous 3-day flows around the predicted time are used to forecast flows for the next time step.
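The array layout described above can be checked with a few lines of NumPy. The following is only a sketch: it builds a small synthetic archive in memory instead of reading a real train.npz, the sizes are made up, and the key names X, Y, X_offset, and Y_offset are taken from the description above as an assumption (the actual archives may use different key names, for example lowercase ones).

```python
import io

import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_samples, lookback, horizon, num_nodes, num_flows = 8, 5, 1, 6, 2

# Synthetic stand-in for train.npz; the key names are assumptions.
buffer = io.BytesIO()
np.savez(
    buffer,
    X=np.zeros((num_samples, lookback, num_nodes, num_flows)),
    Y=np.zeros((num_samples, horizon, num_nodes, num_flows)),
    X_offset=np.arange(-lookback + 1, 1),   # offsets up to the current time 0
    Y_offset=np.arange(1, horizon + 1),     # offsets after the current time
)
buffer.seek(0)


def describe_split(file_like):
    """Return {array name: shape} for every array stored in an .npz archive."""
    with np.load(file_like) as archive:
        return {name: archive[name].shape for name in archive.files}


shapes = describe_split(buffer)
for name, shape in shapes.items():
    print(name, shape)
```

To inspect a downloaded split, pass its path (for instance the train.npz inside the dataset folder) to describe_split in place of the in-memory buffer.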
adj_mx.npz stores the graph adjacency matrix, which encodes the spatial relation between every pair of regions/nodes in the studied area.
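A quick sanity check on an adjacency matrix might look like the sketch below. It uses a toy 4-node chain graph rather than the real adj_mx.npz, because the key under which the matrix is stored in that archive is not stated here. The helper name check_adjacency is hypothetical and not part of this repo.

```python
import numpy as np


def check_adjacency(adj, num_nodes):
    """Validate the shape of an adjacency matrix and summarize it."""
    adj = np.asarray(adj)
    if adj.shape != (num_nodes, num_nodes):
        raise ValueError(
            f"expected shape {(num_nodes, num_nodes)}, got {adj.shape}"
        )
    return {
        "nonzero_entries": int(np.count_nonzero(adj)),
        "symmetric": bool(np.allclose(adj, adj.T)),
    }


# Toy graph: 4 regions connected in a chain 0-1-2-3 (illustrative only).
toy_adj = np.zeros((4, 4))
for i in range(3):
    toy_adj[i, i + 1] = 1.0
    toy_adj[i + 1, i] = 1.0

info = check_adjacency(toy_adj, num_nodes=4)
print(info)
```

The num_nodes argument should match the #nodes dimension of X and Y, so a shape mismatch between the graph and the flow data is caught early.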
⚠️ Note that all datasets are processed as a sliding window view. Raw data of NYCBike1 and BJTaxi are collected from STResNet. Raw data of NYCBike2 and NYCTaxi are collected from STDN. If needed, one can download the original datasets from this link.
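To make the sliding-window idea concrete, here is a minimal sketch of how paired (X, Y) samples can be cut from a raw series using offset lists. It is not the preprocessing script used for these datasets: the function name make_windows is hypothetical, and the offsets below are toy values rather than the 2-hour and 3-day pattern mentioned above.

```python
import numpy as np


def make_windows(series, x_offsets, y_offsets):
    """Cut paired samples from a series of shape (time, nodes, flow_types).

    For each valid current time t, X takes series[t + x_offsets] and
    Y takes series[t + y_offsets].
    """
    x_offsets = np.asarray(x_offsets)
    y_offsets = np.asarray(y_offsets)
    # Earliest t for which the whole lookback window exists.
    first_t = max(0, -int(x_offsets.min()))
    # Exclusive upper bound so the prediction horizon stays in range.
    last_t = series.shape[0] - max(0, int(y_offsets.max()))
    xs = [series[t + x_offsets] for t in range(first_t, last_t)]
    ys = [series[t + y_offsets] for t in range(first_t, last_t)]
    return np.stack(xs), np.stack(ys)


# Toy series: 10 time steps, 1 node, 1 flow type, values 0..9.
series = np.arange(10, dtype=float).reshape(10, 1, 1)
X, Y = make_windows(series, x_offsets=[-2, -1, 0], y_offsets=[1])
print(X.shape, Y.shape)
```

With these toy offsets, each sample uses the current step and the two steps before it to predict the next step, so X and Y share the same first (sample) dimension.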
If the environment is ready, run the following commands to train the model on a specific dataset from {NYCBike1, NYCBike2, NYCTaxi, BJTaxi}.
>> cd ST-SSL
>> ./runme 0 NYCBike1 # 0 specifies the GPU id, NYCBike1 gives the dataset
Note that this repo only contains the NYCBike1 data, because including all datasets would make the repo too large.
If you find the paper useful, please cite the following:
@article{ji2023spatio,
title={Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction},
author={Ji, Jiahao and Wang, Jingyuan and Huang, Chao and Wu, Junjie and Xu, Boren and Wu, Zhenhe and Zhang, Junbo and Zheng, Yu},
journal={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={4},
pages={4356-4364},
year={2023}
}