Open kadattack opened 2 years ago
I have the same problem. How can I fix it?
I downloaded the pems-bay.h5 file from https://zenodo.org/record/4263971 and used https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py to generate the test.npz, train.npz, and val.npz files in ./data/pems_bay/.
When I run the command from README.md:
python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path ./data/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8
I get the following error:
Traceback (most recent call last):
  File "/spacetimeformer-main/spacetimeformer/train.py", line 854, in <module>
    main(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 758, in main
    ) = create_dset(args)
  File "/spacetimeformer-main/spacetimeformer/train.py", line 394, in create_dset
    data = stf.data.metr_la.METR_LA_Data(config.data_path)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 43, in __init__
    x_c_train, y_c_train = self._split_set(context_train)
  File "/spacetimeformer-main/spacetimeformer/data/metr_la/metr_la.py", line 21, in _split_set
    time = 2.0 * x[:, :, 0] - 1.0
IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
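The IndexError means the array that reaches `x[:, :, 0]` has only two axes. The sketch below shows one way that can happen. It assumes, without having checked the spacetimeformer source, that the loader keeps the non-speed channels of a single sensor and then squeezes out singleton axes. `time_features` is a hypothetical stand-in, and the array sizes are illustrative (325 sensors, 12 steps). If the .npz holds only speed plus time of day, a single time channel remains, `np.squeeze` removes that axis, and the 3-index lookup fails.

```python
import numpy as np


def time_features(data: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the loader's time-feature extraction.

    Assumes `data` has shape (samples, seq_len, nodes, features), with
    feature 0 = speed and the remaining features = time features.
    """
    # Keep the time features of node 0, then drop any size-1 axes.
    x = np.squeeze(data[:, :, 0, 1:])
    # Same indexing as in the traceback: requires a 3-D array.
    return 2.0 * x[:, :, 0] - 1.0


# speed + time of day only: 1 time channel, squeezed away -> 2-D array
only_time_of_day = np.zeros((4, 12, 325, 2))
# speed + time of day + 7-way one-hot day of week: 8 time channels -> 3-D array
with_day_of_week = np.zeros((4, 12, 325, 9))

print(time_features(with_day_of_week).shape)  # (4, 12)
try:
    time_features(only_time_of_day)
except IndexError as err:
    print("IndexError:", err)
```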
Have you solved it?
Set add_day_in_week to True in https://github.com/liyaguang/DCRNN/blob/master/scripts/generate_training_data.py#L72. The spacetimeformer loader expects the day-of-week features to be present in the generated files.
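Before retraining, you can confirm that the regenerated files contain the extra features. This is a minimal sketch under stated assumptions. It assumes DCRNN's script saves arrays under the keys `x` and `y`, with a last axis made of 1 speed channel, 1 time-of-day channel (add_time_in_day), and a 7-way one-hot day-of-week block (add_day_in_week). `expected_num_features` and `check_split` are hypothetical helpers, not part of either repository.

```python
import numpy as np


def expected_num_features(add_time_in_day: bool, add_day_in_week: bool) -> int:
    """Assumed last-axis size of x/y written by generate_training_data.py.

    1 speed channel + 1 time-of-day channel + 7 one-hot day-of-week channels.
    """
    return 1 + int(add_time_in_day) + 7 * int(add_day_in_week)


def check_split(path: str) -> None:
    """Raise ValueError if a split file lacks the day-of-week features."""
    want = expected_num_features(add_time_in_day=True, add_day_in_week=True)
    with np.load(path) as split:  # assumed keys: "x" and "y"
        for key in ("x", "y"):
            have = split[key].shape[-1]
            if have != want:
                raise ValueError(
                    f"{path}: '{key}' has {have} features, expected {want}; "
                    "regenerate with add_day_in_week=True"
                )
```

Usage, assuming the paths from the question: `for s in ("train", "val", "test"): check_split(f"./data/pems_bay/{s}.npz")`.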