QData / spacetimeformer

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."
https://arxiv.org/abs/2109.12218
MIT License
808 stars, 191 forks

Getting error on PEMS-BAY dataset #61

Status: open (opened by kadattack 2 years ago)

kadattack commented 2 years ago

I downloaded the pems-bay.h5 file from https://zenodo.org/record/4263971 and used https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py to generate the test.npz, train.npz, and val.npz files in ./data/pems_bay/.
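One way to check what the DCRNN script actually wrote to ./data/pems_bay/ is to print the array shapes in each split. This is a minimal sketch that assumes each .npz file stores its arrays under the keys "x" and "y" with layout (samples, timesteps, sensors, features); `describe_split_files` is a hypothetical helper, not part of spacetimeformer or DCRNN.

```python
import os

import numpy as np


def describe_split_files(data_path: str) -> dict:
    """Return the shapes of the "x" and "y" arrays in the train/val/test .npz files.

    Assumes DCRNN-style files <data_path>/{train,val,test}.npz, each holding
    arrays named "x" and "y" (hypothetical helper, not a spacetimeformer API).
    """
    shapes = {}
    for split in ("train", "val", "test"):
        file_path = os.path.join(data_path, f"{split}.npz")
        # np.load on an .npz returns a lazy NpzFile; the context manager closes it.
        with np.load(file_path) as archive:
            shapes[split] = {key: archive[key].shape for key in ("x", "y")}
    return shapes
```

For example, `describe_split_files("./data/pems_bay/")` returns a dict such as `{"train": {"x": (...), "y": (...)}, ...}`; under the assumed layout, the last entry of each shape is the number of features per sensor and timestep.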

When I run the command from the README.md:

python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path ./data/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8

I get the following error:

Traceback (most recent call last):
  File "/spacetimeformer-main/spacetimeformer/train.py", line 854, in <module>
    main(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 758, in main
    ) = create_dset(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 394, in create_dset
    data = stf.data.metr_la.METR_LA_Data(config.data_path)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 43, in __init__
    x_c_train, y_c_train = self._split_set(context_train)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 21, in _split_set
    time = 2.0 * x[:, :, 0] - 1.0
IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
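The IndexError itself is ordinary NumPy behavior: `x[:, :, 0]` needs at least three axes, so it fails when `x` is 2-D. The sketch below reproduces the message and shows a hypothetical guard, `rescale_time_of_day` (not spacetimeformer's code), that reports the unexpected shape up front. It assumes, as the `2.0 * ... - 1.0` expression suggests, that feature 0 is a time-of-day fraction in [0, 1] being rescaled to [-1, 1].

```python
import numpy as np


def rescale_time_of_day(x: np.ndarray) -> np.ndarray:
    """Map feature 0 of a (samples, timesteps, features) array from [0, 1] to [-1, 1].

    Hypothetical helper mirroring the failing line `2.0 * x[:, :, 0] - 1.0`;
    it checks the dimensionality first so a wrong input shape gets a clear message.
    """
    if x.ndim != 3:
        raise ValueError(
            f"expected a 3-D array (samples, timesteps, features), got shape {x.shape}"
        )
    return 2.0 * x[:, :, 0] - 1.0


def reproduce_index_error() -> str:
    """Index a 2-D array with three indices, as in the traceback, and return the message."""
    two_d = np.zeros((4, 12))
    try:
        two_d[:, :, 0]
    except IndexError as err:
        return str(err)
    return ""
```

The guard does not fix the data; it only turns a bare IndexError into a message that names the shape the loader actually received.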

TWENTY-FOU commented 2 years ago

I have the same problem. How can I fix it?

TWENTY-FOU commented 2 years ago


Did you manage to solve it?

szdrnja commented 1 year ago

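Set add_day_in_week to True in https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py#L72 and regenerate the .npz files. spacetimeformer's METR_LA_Data loader, which is also used for PEMS-BAY, expects these day-of-week features to be present.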
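To illustrate why the flag changes the data, here is a simplified NumPy sketch of DCRNN-style preprocessing that appends calendar features to each sensor reading. It is not the DCRNN script: `add_calendar_features` is a hypothetical name, and the layout (one value channel, one time-of-day fraction, then seven one-hot day-of-week channels with Monday = 0) is an assumption modeled on generate_training_data.py. In this sketch, leaving add_day_in_week at False yields 2 features per reading instead of 9.

```python
import numpy as np


def add_calendar_features(values, timestamps, add_time_in_day=True, add_day_in_week=False):
    """Stack calendar features onto a (num_samples, num_nodes) array of sensor readings.

    timestamps: datetime64 array of length num_samples (one per row of `values`).
    Returns an array of shape (num_samples, num_nodes, num_features).
    """
    num_samples, num_nodes = values.shape
    timestamps = timestamps.astype("datetime64[ns]")
    feature_list = [values[:, :, np.newaxis]]  # channel 0: the raw reading

    if add_time_in_day:
        midnight = timestamps.astype("datetime64[D]")
        # Fraction of the day elapsed, in [0, 1); identical for every sensor.
        time_in_day = (timestamps - midnight) / np.timedelta64(1, "D")
        feature_list.append(np.tile(time_in_day[:, None, None], (1, num_nodes, 1)))

    if add_day_in_week:
        # Day of week with Monday = 0; 1970-01-01 (epoch day 0) was a Thursday (= 3).
        day_index = (timestamps.astype("datetime64[D]").astype(np.int64) + 3) % 7
        day_in_week = np.zeros((num_samples, num_nodes, 7))
        # Set channel day_index[i] to 1 for every sensor at sample i.
        day_in_week[np.arange(num_samples), :, day_index] = 1.0
        feature_list.append(day_in_week)

    return np.concatenate(feature_list, axis=-1)
```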