chaoshangcs / GTS

Discrete Graph Structure Learning for Forecasting Multiple Time Series, ICLR 2021.
Apache License 2.0

DataLoader in function of load_dataset in utils.py #22

Open fuyuyuputao opened 2 years ago

fuyuyuputao commented 2 years ago
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)

Hi! There may be a very small error here. I think the second line, the one for val_loader, should use "val_batch_size" rather than "test_batch_size", even though both sizes are equal to 64.

chaoshangcs commented 2 years ago

Hi, thanks for your message. Here we used only one variable, "test_batch_size", for both the validation dataloader and the test dataloader. That is a great suggestion. We could use two variables here for more clarity. Thanks!
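
A minimal sketch of what the two-variable version could look like. This is not the repository's code. The `DataLoader` class below is a simplified, hypothetical stand-in for the helper in utils.py, and the function name `build_loaders`, the `val_batch_size` parameter, and its fallback behaviour are all assumptions made for illustration.

```python
import random


class DataLoader:
    """Simplified, hypothetical stand-in for the batching helper in utils.py."""

    def __init__(self, xs, ys, batch_size, shuffle=False):
        if len(xs) != len(ys):
            raise ValueError("xs and ys must have the same length")
        pairs = list(zip(xs, ys))
        if shuffle:
            random.shuffle(pairs)  # shuffle inputs and targets together
        self.pairs = pairs
        self.batch_size = batch_size

    def __iter__(self):
        # Yield consecutive (inputs, targets) batches; the last one may be shorter.
        for start in range(0, len(self.pairs), self.batch_size):
            batch = self.pairs[start:start + self.batch_size]
            yield [x for x, _ in batch], [y for _, y in batch]


def build_loaders(data, batch_size, val_batch_size=None, test_batch_size=None):
    """Attach train/val/test loaders to `data`, one batch size per split."""
    # Assumed fallbacks: test defaults to the training size, and validation
    # defaults to the test size, which mirrors the single-variable behaviour.
    if test_batch_size is None:
        test_batch_size = batch_size
    if val_batch_size is None:
        val_batch_size = test_batch_size
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
    return data
```

With the fallback, callers that pass only "test_batch_size" get the same loaders as before, while callers that want a different validation batch size can set "val_batch_size" explicitly.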