GestaltCogTeam / BasicTS

A Fair and Scalable Time Series Forecasting Benchmark and Toolkit.
https://ieeexplore.ieee.org/document/10726722/
Apache License 2.0

On the relative size of training MAE vs. validation MAE #134

Closed LenarBe closed 3 months ago

LenarBe commented 4 months ago

Hello zezhi shao, and thank you for providing such a complete framework! While running the code I ran into the following question. Taking GWNet as an example, for a long stretch of training the validation MAE is smaller than the training MAE. I suspect this may be related to how the underlying loss is computed, but since the MAE metric should not include any regularization term, I am very curious whether this consistent gap comes from the model itself (e.g., it simply fits the validation set better) or from some other cause. Thanks a lot!
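For concreteness, this is roughly what I understand the masked MAE in this line of traffic benchmarks to be (my own sketch following the DCRNN/GWNet convention, not the exact `basicts.losses.masked_mae`); note that there is indeed no regularization term in it:

```python
import torch

def masked_mae(prediction: torch.Tensor, target: torch.Tensor,
               null_val: float = 0.0) -> torch.Tensor:
    # Ignore entries whose target equals null_val (missing sensor readings).
    mask = (target != null_val).float()
    mask = mask / mask.mean()          # renormalize so valid entries keep unit weight
    mask = torch.nan_to_num(mask)      # guard against an all-masked batch (0/0 -> nan)
    loss = torch.abs(prediction - target) * mask
    return torch.nan_to_num(loss).mean()
```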

zezhishao commented 4 months ago

Hi, thanks for the report. Could you provide more details: which dataset are you running on, how is the model configured, and how can we reproduce it?

LenarBe commented 4 months ago

Sure. Take GWNet on METR-LA with addaptadj disabled as an example: the situation seems to persist until around epoch 45. Screenshot below:

[screenshot: training vs. validation MAE logs]

The corresponding config:

```python
import os
import sys

# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../../.."))
import torch
from easydict import EasyDict

from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.data import TimeSeriesForecastingDataset
from basicts.losses import masked_mae
from basicts.utils import load_adj

from .arch import GraphWaveNet

CFG = EasyDict()

# ================= general =================
CFG.DESCRIPTION = "Graph WaveNet model configuration"
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
CFG.DATASET_CLS = TimeSeriesForecastingDataset
CFG.DATASET_NAME = "METR-LA"
CFG.DATASET_TYPE = "Traffic speed"
CFG.DATASET_INPUT_LEN = 12
CFG.DATASET_OUTPUT_LEN = 12
CFG.GPU_NUM = 1
CFG.NULL_VAL = 0.0

# ================= environment =================
CFG.ENV = EasyDict()
CFG.ENV.SEED = 1
CFG.ENV.CUDNN = EasyDict()
CFG.ENV.CUDNN.ENABLED = True

# ================= model =================
CFG.MODEL = EasyDict()
CFG.MODEL.NAME = "GWNet"
CFG.MODEL.ARCH = GraphWaveNet
adj_mx, _ = load_adj("datasets/" + CFG.DATASET_NAME + "/adj_mx.pkl", "doubletransition")
CFG.MODEL.PARAM = {
    "num_nodes": 207,
    "supports": [torch.tensor(i) for i in adj_mx],
    "dropout": 0.3,
    "gcn_bool": True,
    "addaptadj": False,
    "aptinit": None,
    "in_dim": 2,
    "out_dim": 12,
    "residual_channels": 32,
    "dilation_channels": 32,
    "skip_channels": 256,
    "end_channels": 512,
    "kernel_size": 2,
    "blocks": 4,
    "layers": 2
}
CFG.MODEL.FORWARD_FEATURES = [0, 1]
CFG.MODEL.TARGET_FEATURES = [0]

# ================= optim =================
CFG.TRAIN = EasyDict()
CFG.TRAIN.LOSS = masked_mae
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": 0.002,
    "weight_decay": 0.0001,
}
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    "milestones": [1, 50],
    "gamma": 0.5
}

# ================= train =================
CFG.TRAIN.CLIP_GRAD_PARAM = {"max_norm": 5.0}
CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    "checkpoints",
    "_".join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
# train data
CFG.TRAIN.DATA = EasyDict()
# read data
CFG.TRAIN.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.PREFETCH = False
CFG.TRAIN.DATA.SHUFFLE = True
CFG.TRAIN.DATA.NUM_WORKERS = 2
CFG.TRAIN.DATA.PIN_MEMORY = False

# ================= validate =================
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
# validating data
CFG.VAL.DATA = EasyDict()
# read data
CFG.VAL.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.VAL.DATA.BATCH_SIZE = 64
CFG.VAL.DATA.PREFETCH = False
CFG.VAL.DATA.SHUFFLE = False
CFG.VAL.DATA.NUM_WORKERS = 2
CFG.VAL.DATA.PIN_MEMORY = False

# ================= test =================
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
# test data
CFG.TEST.DATA = EasyDict()
# read data
CFG.TEST.DATA.DIR = "datasets/" + CFG.DATASET_NAME
# dataloader args, optional
CFG.TEST.DATA.BATCH_SIZE = 64
CFG.TEST.DATA.PREFETCH = False
CFG.TEST.DATA.SHUFFLE = False
CFG.TEST.DATA.NUM_WORKERS = 2
CFG.TEST.DATA.PIN_MEMORY = False

# ================= evaluate =================
CFG.EVAL = EasyDict()
CFG.EVAL.HORIZONS = [3, 6, 12]
```
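For completeness, this is my understanding of the supports that `"doubletransition"` produces (a sketch following the original Graph WaveNet convention; the actual `load_adj` implementation may differ):

```python
import numpy as np

def asym_adj(adj: np.ndarray) -> np.ndarray:
    """Row-normalize the adjacency matrix: D^{-1} A (random-walk transition)."""
    rowsum = adj.sum(axis=1)
    d_inv = np.where(rowsum > 0, 1.0 / rowsum, 0.0)
    return d_inv[:, None] * adj

def double_transition(adj: np.ndarray) -> list:
    # Forward and backward transition matrices, used as GWNet's two fixed supports.
    return [asym_adj(adj), asym_adj(adj.T)]
```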

Thanks!

LenarBe commented 4 months ago

I'm not sure whether I've misunderstood something or whether it's a configuration issue, but in the screenshot the training MAE of 2.98 is indeed larger than the validation MAE of 2.87. Are these two MAEs computed in different ways, for example over different horizons?

zezhishao commented 4 months ago

I looked through the training logs of the other methods, and METR-LA seems to behave this way for all of them, so I am not sure whether this is a bug.
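One plausible explanation (a guess, I have not verified it against the runner code): the training MAE reported for an epoch is typically the average of the batch losses logged while the weights are still being updated, whereas the validation MAE is computed after the epoch with the final weights; with dropout=0.3 active only in train mode, the gap can widen further. A toy illustration, independent of BasicTS:

```python
import torch
import torch.nn as nn

# Toy experiment: epoch-level "train MAE" averages batch losses logged while
# the weights are still moving; "val MAE" is measured with the final weights.
torch.manual_seed(0)
w_true = torch.randn(8, 1)
x_train, x_val = torch.randn(6400, 8), torch.randn(1600, 8)
y_train, y_val = x_train @ w_true, x_val @ w_true

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.05)

batch_losses = []
model.train()
for xb, yb in zip(x_train.split(64), y_train.split(64)):
    loss = (model(xb) - yb).abs().mean()
    batch_losses.append(loss.item())  # logged with the *current* weights
    opt.zero_grad()
    loss.backward()
    opt.step()

train_mae = sum(batch_losses) / len(batch_losses)  # epoch-average train MAE
model.eval()
with torch.no_grad():
    val_mae = (model(x_val) - y_val).abs().mean().item()  # final weights
print(f"epoch-average train MAE: {train_mae:.3f}, post-epoch val MAE: {val_mae:.3f}")
```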

finleywang commented 3 months ago

Dear LenarBe, thank you for your attention. If BasicTS helped you, please cite the following paper in your work. Best wishes:

[1] Shao Z, Wang F, Xu Y, et al. Exploring Progress in Multivariate Time Series Forecasting: Comprehensive Benchmarking and Heterogeneity Analysis. arXiv preprint arXiv:2310.06119, 2023.

```bibtex
@misc{shao2023exploringprogressmultivariatetime,
  title={Exploring Progress in Multivariate Time Series Forecasting: Comprehensive Benchmarking and Heterogeneity Analysis},
  author={Zezhi Shao and Fei Wang and Yongjun Xu and Wei Wei and Chengqing Yu and Zhao Zhang and Di Yao and Guangyin Jin and Xin Cao and Gao Cong and Christian S. Jensen and Xueqi Cheng},
  year={2023},
  eprint={2310.06119},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2310.06119},
}
```