thuml / Time-Series-Library

A Library for Advanced Deep Time Series Models.
MIT License
6.41k stars 1.02k forks source link

augmentation.py 的Index out of range #510

Closed Hawk0321 closed 1 month ago

Hawk0321 commented 1 month ago

augementation.py的操作有的对x.shape[2]做处理,但是dataloader.py的data_x是二维的:

class Dataset_Custom(Dataset):
        self.data_x = data[border1:border2]
        if self.set_type == 0 and self.args.augmentation_ratio > 0:
            self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)

我不知道该怎么处理,求助

DigitalLifeYZQiu commented 1 month ago

感谢您提出这个非常有价值的issue!

数据增广函数中出现针对第三维的操作的根本原因在于,这些算法最初是用于处理训练过程中各个batch中数据的增广任务。在时序预测任务下,输入增广函数的数据维度为二维:(数据集长度,变量数),切分为batch后数据维度提升为三维:(batch大小,序列长度,变量数)。您在数据增广函数中看到的x,其实指的是在batch维度下的三维数据。从实现上来讲,可以将数据增广操作加在exp/exp_long_term_forecasting.py文件中训练部分遍历train_loader的循环中。这里是一个例子:

for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
    if self.args.augmentation_ratio > 0 and self.args.augmentation_ratio > 0:
        batch_x, batch_y, augmentation_tags = run_augmentation_single(batch_x,batch_y,self.args)
    ......(other training steps)......

正如现版本代码中所述,数据增广操作被放到了class Dataset_Custom__read_data__函数中,即对整个数据集进行增广。在这种情况下,这里的可以将输入数据视为一个“巨大的batch”。具体实现上,可以将二维:(数据集长度,变量数)变为三维(1,数据集长度,变量数)。这里的1可以认为对应于前文提到的batch大小的概念,即将整个数据集当作一个大batch。

总的来说,当前数据增广代码的数据输入是有问题的,非常感谢您指出了这个bug。目前仓库中的代码文件和我本地存储的代码文件有一些出入,似乎是提交过程中遗漏了一些内容。对于这一疏漏给您造成的困惑,我们深感抱歉!在后续版本中我们将修复这一问题。

Hawk0321 commented 1 month ago

非常感谢