Closed Hawk0321 closed 1 month ago
感谢您提出这个非常有价值的issue!
数据增广函数中出现针对第三维的操作的根本原因在于,这些算法最初是用于处理训练过程中各个batch中数据的增广任务。在时序预测任务下,输入增广函数的数据维度为二维:(数据集长度,变量数),切分为batch后数据维度提升为三维:(batch大小,序列长度,变量数)。您在数据增广函数中看到的x,其实指的是在batch维度下的三维数据。从实现上来讲,可以将数据增广操作加在exp/exp_long_term_forecasting.py
文件中训练部分遍历train_loader
的循环中。这里是一个例子:
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
if self.args.augmentation_ratio > 0 and self.args.augmentation_ratio > 0:
batch_x, batch_y, augmentation_tags = run_augmentation_single(batch_x,batch_y,self.args)
......(other training steps)......
正如现版本代码中所述,数据增广操作被放到了class Dataset_Custom
的__read_data__
函数中,即对整个数据集进行增广。在这种情况下,这里的可以将输入数据视为一个“巨大的batch”。具体实现上,可以将二维:(数据集长度,变量数)变为三维(1,数据集长度,变量数)。这里的1可以认为对应于前文提到的batch大小的概念,即将整个数据集当作一个大batch。
总的来说,当前数据增广代码的数据输入是有问题的,非常感谢您指出了这个bug。目前仓库中的代码文件和我本地存储的代码文件有一些出入,似乎是提交过程中遗漏了一些内容。对于这一疏漏给您造成的困惑,我们深感抱歉!在后续版本中我们将修复这一问题。
非常感谢
augementation.py的操作有的对x.shape[2]做处理,但是dataloader.py的data_x是二维的:
我不知道该怎么处理,求助