The official implementation of the ICLR 2024 paper entitled "Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation".
In this project, we propose a novel framework, GPD, which performs generative pre-training on a collection of model parameters optimized with data from source cities. Our proposed approach recasts spatio-temporal graph transfer learning as pre-training a generative hypernetwork, which generates tailored model parameters guided by prompts. Our framework has the potential to revolutionize smart city applications in data-scarce environments and contribute to more sustainable and efficient urban development.
Use the following command to install all of the Python modules and packages used in this project:
pip install -r requirements.txt
The data used for training and evaluation can be found in Time-Series data. After downloading the data, move it to ./Data.
For each city, we provide the following data:
Graph data: records the adjacency matrix of the spatio-temporal graph.
Time series data: records the temporal sequential data for each node.
We provide two time-series datasets: crowd flow (DC, BM, and man) and traffic speed (metr-la, pems-bay, shenzhen, and chengdu_m).
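To make the two kinds of per-city data concrete, the sketch below pairs a dense adjacency matrix with a node-level time series and cuts the series into input/target windows, the usual way such data is fed to spatio-temporal graph models. This is a sketch under assumptions: the (num_steps, num_nodes) array layout, the window lengths (12 in, 3 out), the toy ring graph, and the helper `make_windows` are all hypothetical and are not the repository's actual loader or file format.

```python
import numpy as np


def make_windows(series: np.ndarray, history: int, horizon: int):
    """Split a (num_steps, num_nodes) series into input/target windows.

    Returns X with shape (num_samples, history, num_nodes) and
    Y with shape (num_samples, horizon, num_nodes).
    """
    num_steps = series.shape[0]
    num_samples = num_steps - history - horizon + 1
    if num_samples <= 0:
        raise ValueError("series too short for the requested window sizes")
    X = np.stack([series[i:i + history] for i in range(num_samples)])
    Y = np.stack([series[i + history:i + history + horizon] for i in range(num_samples)])
    return X, Y


# Hypothetical toy city: 4 nodes on an undirected ring graph.
num_nodes, num_steps = 4, 20
adj = np.zeros((num_nodes, num_nodes))
for i in range(num_nodes):
    j = (i + 1) % num_nodes
    adj[i, j] = adj[j, i] = 1.0

# Synthetic readings: one row per time step, one column per node.
series = np.arange(num_steps * num_nodes, dtype=float).reshape(num_steps, num_nodes)

X, Y = make_windows(series, history=12, horizon=3)
print(adj.shape, X.shape, Y.shape)  # (4, 4) (6, 12, 4) (6, 3, 4)
```

With 20 steps, a 12-step history, and a 3-step horizon, there are 20 - 12 - 3 + 1 = 6 windows per node set.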
The details of these two data sets are as follows:
To train node-level models with the traffic dataset, run:
cd Pretrain
CUDA_VISIBLE_DEVICES=0 python main.py --taskmode task4 --model v_GWN --test_data metr-la --ifnewname 1 --aftername TrafficData
After the models are fully trained, run Pretrain/PrepareParams/model2tensor.py to extract the parameters from the trained models, and put the resulting parameter dataset in ./Data.
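The README does not describe what model2tensor.py does internally. The sketch below assumes that "extracting parameters" means flattening each trained model's named weight arrays into one 1-D vector and stacking one vector per model into a parameter dataset, which is the kind of input a parameter-generating diffusion model needs. The helpers `flatten_params` and `unflatten_params` and the weight names are hypothetical; NumPy is used to keep the example self-contained. In PyTorch, `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters` perform the same flatten/restore round trip on a model's parameters.

```python
import numpy as np


def flatten_params(params: dict):
    """Concatenate named weight arrays into one 1-D vector plus a layout record."""
    layout = [(name, arr.shape) for name, arr in params.items()]
    vector = np.concatenate([arr.ravel() for arr in params.values()])
    return vector, layout


def unflatten_params(vector: np.ndarray, layout):
    """Inverse of flatten_params: slice the vector back into named arrays."""
    params, offset = {}, 0
    for name, shape in layout:
        size = int(np.prod(shape))
        params[name] = vector[offset:offset + size].reshape(shape)
        offset += size
    if offset != vector.size:
        raise ValueError("vector length does not match layout")
    return params


# Hypothetical weights of one small trained model (names are illustrative only).
rng = np.random.default_rng(0)
model_a = {
    "conv.weight": rng.normal(size=(8, 4)),
    "conv.bias": rng.normal(size=(8,)),
    "out.weight": rng.normal(size=(1, 8)),
}
vec, layout = flatten_params(model_a)
print(vec.shape)  # (48,)

# Stacking one flattened vector per trained model gives a (num_models, num_params) dataset.
models = [{k: v + i for k, v in model_a.items()} for i in range(3)]
param_dataset = np.stack([flatten_params(m)[0] for m in models])
print(param_dataset.shape)  # (3, 48)
```

Keeping the layout record is what lets generated vectors be loaded back into a model of the same architecture later.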
To train the diffusion model and generate the parameters for the target city, run:
cd GPD
CUDA_VISIBLE_DEVICES=0 python 1Dmain.py --expIndex 140 --targetDataset metr-la --modeldim 512 --epochs 80000 --diffusionstep 500 --basemodel v_GWN --denoise Trans1
expIndex: assigns an identifying number to the experiment.
targetDataset: specifies the target dataset, which can be selected from ['DC', 'BM', 'man', 'metr-la', 'pems-bay', 'shenzhen', 'chengdu_m'].
modeldim: specifies the hidden dimension of the Transformer.
epochs: specifies the number of training iterations.
diffusionstep: specifies the total number of steps in the diffusion process.
basemodel: specifies the spatio-temporal graph model, which can be selected from ['v_STGCN5', 'v_GWN'].
denoise: specifies the conditioning strategy, which can be selected from ['Trans1', 'Trans2', 'Trans3', 'Trans4', 'Trans5'].
The sampling results are saved in GPD/Output/expXX/.
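As a rough illustration of what --diffusionstep controls, the sketch below shows the standard DDPM forward (noising) process applied to one flattened parameter vector, together with the noise-prediction target a denoiser is trained on. The linear beta schedule (1e-4 to 0.02) and the mean-squared-error noise objective are common DDPM defaults assumed here, not confirmed details of GPD; the denoiser is only a placeholder, whereas GPD uses a Transformer-based denoiser configured by --modeldim and --denoise.

```python
import numpy as np


def linear_beta_schedule(num_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    """DDPM-style linear noise schedule (an assumption; GPD's schedule may differ)."""
    return np.linspace(beta_start, beta_end, num_steps)


def q_sample(x0, t, alphas_cumprod, noise):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    abar = alphas_cumprod[t]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise


diffusionstep = 500  # mirrors --diffusionstep 500 in the command above
betas = linear_beta_schedule(diffusionstep)
alphas_cumprod = np.cumprod(1.0 - betas)  # abar_t, shrinks toward 0 as t grows

rng = np.random.default_rng(0)
x0 = rng.normal(size=(48,))              # one flattened parameter vector (toy size)
noise = rng.normal(size=x0.shape)
t = int(rng.integers(0, diffusionstep))  # random timestep, as in DDPM training
x_t = q_sample(x0, t, alphas_cumprod, noise)

# A trained denoiser would predict `noise` from (x_t, t, prompt).
# Placeholder only: a real model's output replaces these zeros.
predicted_noise = np.zeros_like(noise)
loss = np.mean((predicted_noise - noise) ** 2)
print(f"t={t}, abar_t={alphas_cumprod[t]:.4f}, loss={loss:.3f}")
```

Because abar_t falls close to zero by the last step, x_T is nearly pure Gaussian noise; sampling runs this process in reverse, which is why the number of diffusion steps affects both training and generation cost.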
To finetune the generated parameters of the target city and evaluate, run:
cd Pretrain
CUDA_VISIBLE_DEVICES=0 python main.py --taskmode task7 --model v_GWN --test_data metr-la --ifnewname 1 --aftername finetune_7days --epochs 600 --target_days 7
taskmode: 'task7' means fine-tuning after diffusion sampling.
model: specifies the spatio-temporal graph model, which can be selected from ['v_STGCN5', 'v_GWN'].
test_data: specifies the dataset, which can be selected from ['DC', 'BM', 'man', 'metr-la', 'pems-bay', 'shenzhen', 'chengdu_m'].
ifnewname: set to 1 to better distinguish the results of the current experiment.
aftername: used with --ifnewname 1 to give an identifying name to the log file and results folder of the current experiment.
epochs: specifies the number of iterations.
target_days: specifies the amount of target-city data, in days, used in the fine-tuning stage.
Here is an example of the overall workflow. To set 'metr-la' as the target city:
Pretraining: set test_data to 'pems-bay', 'chengdu_m', and 'shenzhen' (PEMS-Bay, Didi-Chengdu, and Didi-Shenzhen) in turn to pretrain the models of the other three source cities.
Diffusion: set targetDataset to 'metr-la'.
Fine-tuning: set test_data to 'metr-la'.
Since fine-tuning and pretraining share the same code framework and the same set of parameter names, this can be a little confusing; I will try to make the distinction between them clearer in later versions of the code.
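The workflow above can be sketched as a small dry-run script that prints the command for each stage in order. It only reuses flags and values from the example commands in this README; the helper names are hypothetical, CUDA_VISIBLE_DEVICES is omitted, and the working directory and (absence of) arguments for model2tensor.py are assumptions, since the README does not document them.

```python
import shlex

TARGET = "metr-la"
# Source cities for the metr-la example, using the dataset names accepted by --test_data.
SOURCE_CITIES = ["pems-bay", "chengdu_m", "shenzhen"]


def pretrain_cmd(city):
    """Stage 1 (run in Pretrain/): train a base model for one source city."""
    return ["python", "main.py", "--taskmode", "task4", "--model", "v_GWN",
            "--test_data", city, "--ifnewname", "1", "--aftername", "TrafficData"]


def diffusion_cmd(target):
    """Stage 3 (run in GPD/): train the diffusion model and sample target parameters."""
    return ["python", "1Dmain.py", "--expIndex", "140", "--targetDataset", target,
            "--modeldim", "512", "--epochs", "80000", "--diffusionstep", "500",
            "--basemodel", "v_GWN", "--denoise", "Trans1"]


def finetune_cmd(target):
    """Stage 4 (run in Pretrain/): fine-tune the generated parameters and evaluate."""
    return ["python", "main.py", "--taskmode", "task7", "--model", "v_GWN",
            "--test_data", target, "--ifnewname", "1", "--aftername", "finetune_7days",
            "--epochs", "600", "--target_days", "7"]


steps = [("Pretrain", pretrain_cmd(city)) for city in SOURCE_CITIES]
# Stage 2: extract parameters (working directory assumed; arguments, if any, not documented here).
steps.append(("Pretrain/PrepareParams", ["python", "model2tensor.py"]))
steps.append(("GPD", diffusion_cmd(TARGET)))
steps.append(("Pretrain", finetune_cmd(TARGET)))

# Dry run: print the commands instead of executing them.
for workdir, cmd in steps:
    print(f"(cd {workdir} && {shlex.join(cmd)})")
```

Printing rather than running keeps the sketch safe to try; the printed lines can be pasted into a shell once the paths match your checkout.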
@inproceedings{
yuan2024spatiotemporal,
title={Spatio-Temporal Few-Shot Learning via Diffusive Neural Network Generation},
author={Yuan Yuan and Chenyang Shao and Jingtao Ding and Depeng Jin and Yong Li},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=QyFm3D3Tzi}
}