yantijin / ScoreGradPred

Code for *ScoreGrad: Multivariate Probabilistic Time Series Forecasting with Continuous Energy-based Generative Models*

The hyperparameters in traffic_nips. #4

Closed: hanlaoshi closed this issue 1 year ago

hanlaoshi commented 1 year ago

Hi there,

I was wondering if you would be willing to share the hyperparameters that you used for the traffic_nips dataset in your paper? I have been trying to use the hyperparameters provided for the electricity_nips dataset, but the results that I obtained on the traffic_nips dataset were significantly different from those reported in the paper. If you could share your hyperparameters, I would greatly appreciate it.

Thank you for your time and assistance.

yantijin commented 1 year ago

Hi~, thanks for your attention! First, the hyperparameter settings differ across datasets. We find that beta1, beta2, scale, training epochs, and the dimension of the feature states are very important for forecasting performance.

In addition, in our experiments the results on traffic_nips were more sensitive to hyperparameters than those on the other datasets. Here are some trained weights and logs of ScoreGrad on traffic (Google Drive). I hope they help.
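
Since the reply names beta1, beta2, scale, training epochs and the feature-state dimension as the sensitive settings, a small grid search over them is one way to explore traffic_nips. This is a minimal sketch: mapping beta1/beta2 to `beta_min`/`beta_max`, "scale" to `num_scales` and the feature-state dimension to `num_cells` is an assumption based on the config field names that appear later in this thread, and the candidate values are illustrative, not the authors' search space.

```python
import itertools

# Illustrative candidate values. The mapping from the names in the reply
# (beta1, beta2, scale, epochs, feature-state dimension) to these config keys
# is an assumption, not something the authors stated.
SEARCH_SPACE = {
    "beta_min": [0.01, 0.1],      # assumed to correspond to "beta1"
    "beta_max": [10, 15, 20],     # assumed to correspond to "beta2"
    "num_scales": [100, 200],     # assumed to correspond to "scale"
    "epochs": [20, 30, 50],
    "num_cells": [64, 128],       # assumed "dimension of feature states"
}


def build_grid(space):
    """Return every combination of the search space as a list of dicts."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in itertools.product(*space.values())]


grid = build_grid(SEARCH_SPACE)
print(len(grid))  # 2 * 3 * 2 * 3 * 2 combinations
print(grid[0])
```

Each dict would then be written into the training config; given how sensitive traffic_nips is said to be, training each combination under more than one seed is probably worthwhile.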

hanlaoshi commented 1 year ago

I have been using these hyperparameters to train the subvpsde model, and while some results match those reported in the paper, others are worse. Could you provide any insight into why this might be?

I have attached my code and results for reference. Any help or suggestions would be greatly appreciated.

Thank you for your time and assistance.

from gluonts.dataset.repository.datasets import get_dataset
from score_sde.trainer import Trainer  # class path as printed in the logs of this thread
# ScoreGradEstimator and `config` come from the ScoreGrad repository (their imports are not shown here).

dataset = get_dataset("traffic_nips", regenerate=False)  # assumed: dataset loaded through GluonTS

estimator = ScoreGradEstimator(
    input_size=3856,
    freq='H',
    prediction_length=dataset.metadata.prediction_length,
    target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),
    num_layers=2,
    num_cells=128,
    cell_type='GRU',
    num_parallel_samples=config.num_parallel_samples,
    dropout_rate=0.1,
    conditioning_length=100,
    diff_steps=100,
    beta_min=0.01,
    beta_end=10,
    residual_layers=8,
    residual_channels=8,
    dilation_cycle_length=2,
    scaling=False,
    md_type='subvpsde',
    continuous=True,
    reduce_mean=True,
    likelihood_weighting=False,
    config=config,
    trainer=Trainer(
        epochs=20,
        batch_size=64,
        num_batches_per_epoch=100,
        learning_rate=1e-3,
        decay=None,
        device=config.device,
        wandb_mode='disabled',
        config=config,
    ),
)

The results are as follows:

##################
the 0-th time, CRPS_sum: 0.04249960683995443
the 1-th time, CRPS_sum: 0.12992904179415346
the 2-th time, CRPS_sum: 0.06056683129822026
the 3-th time, CRPS_sum: 0.03662869992440223
##################
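
The four CRPS_sum values above come from repeated training runs with the same settings. A quick summary with Python's standard `statistics` module makes the run-to-run spread explicit (values copied from the log above).

```python
import statistics

# CRPS_sum from the four training runs reported above (same hyperparameters).
crps_sum_runs = [
    0.04249960683995443,
    0.12992904179415346,
    0.06056683129822026,
    0.03662869992440223,
]

mean = statistics.mean(crps_sum_runs)
stdev = statistics.stdev(crps_sum_runs)  # sample standard deviation
spread = max(crps_sum_runs) / min(crps_sum_runs)

print(f"mean={mean:.4f} stdev={stdev:.4f} max/min={spread:.1f}x")
```

The best and worst runs differ by roughly a factor of 3.5 under identical settings, so a single training run is hard to compare against one reported number; averaging over several seeds is probably more informative.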

hanlaoshi commented 1 year ago

I wanted to take a moment to express my sincere gratitude for your impressive work and generous sharing of your well-trained model and hyperparameters. I am writing to inquire about some discrepancies that I have encountered when using the trained models that you have provided. Specifically, I have noticed differences in the results obtained from the subvpsde model, the vpsde model, and the vesde model as compared to the results reported in the paper.

The average CRPS_sum values I obtained are 0.37 for the subvpsde model, 0.063 for the vpsde model, and 0.040 for the vesde model.

The output is as follows:

<class 'score_sde.trainer.Trainer'>
md_type:{}, num_cells:{} subvpsde 128
------the method of ScoreGrad------is:subvpsde
CRPS: 0.2302250774867946 ND: 0.2651221122980262 NRMSE: 0.6391780783176729
CRPS-Sum: 0.38324099030932024 ND-Sum: 0.418940006309006 NRMSE-Sum: 0.46314387977726235
####################
CRPS: 0.22609270020021577 ND: 0.26049269979067335 NRMSE: 0.6300855898901556
CRPS-Sum: 0.37491109214372337 ND-Sum: 0.410651565873248 NRMSE-Sum: 0.45159467238735496
####################
CRPS: 0.22783096526647342 ND: 0.2615933016714219 NRMSE: 0.634868596561549
CRPS-Sum: 0.37804421871533755 ND-Sum: 0.4116224129779668 NRMSE-Sum: 0.4570603758903712
####################
CRPS: 0.2293369145540681 ND: 0.26327622332101513 NRMSE: 0.6362867128251717
CRPS-Sum: 0.3803011485287016 ND-Sum: 0.41274235519271607 NRMSE-Sum: 0.45926252045209287
####################
CRPS: 0.22853455125355113 ND: 0.26399291421867005 NRMSE: 0.6346816895335943
CRPS-Sum: 0.37761255606613836 ND-Sum: 0.4135018203405725 NRMSE-Sum: 0.45687966033078964
####################
CRPS: 0.22484617339930038 ND: 0.25881541790393187 NRMSE: 0.6283151866977141
CRPS-Sum: 0.37411721712527196 ND-Sum: 0.4085678253804295 NRMSE-Sum: 0.45002123280558826
####################
CRPS: 0.22747048312856186 ND: 0.26091447891800174 NRMSE: 0.6350557726104186
CRPS-Sum: 0.3791473036535954 ND-Sum: 0.4129624572360978 NRMSE-Sum: 0.4572322168127546
####################
CRPS: 0.22833773954919803 ND: 0.26298050769601805 NRMSE: 0.6343096974415163
CRPS-Sum: 0.3778284225351686 ND-Sum: 0.4120888925835898 NRMSE-Sum: 0.4567079932542919
####################
CRPS: 0.22682725521098776 ND: 0.2610285602321497 NRMSE: 0.6313838068697308
CRPS-Sum: 0.37649053984627884 ND-Sum: 0.41183696456561686 NRMSE-Sum: 0.45380611791268277
####################
CRPS: 0.22789534208666795 ND: 0.262091913290372 NRMSE: 0.633900837419824
CRPS-Sum: 0.3781326886016716 ND-Sum: 0.41276157828307297 NRMSE-Sum: 0.45602799058550514
####################
---------------End of Score_sub-VP SDE!!!---------------
####################
<class 'score_sde.trainer.Trainer'>
md_type:{}, num_cells:{} vpsde 128
------the method of ScoreGrad------is:vpsde
CRPS: 0.2601687852986397 ND: 0.2065687765416458 NRMSE: 0.4146478438178124
CRPS-Sum: 0.0631504743008402 ND-Sum: 0.07544346916666013 NRMSE-Sum: 0.09444306545666831
####################
CRPS: 0.26008546358768597 ND: 0.20668198713469552 NRMSE: 0.41447197152792803
CRPS-Sum: 0.06264833847738471 ND-Sum: 0.07464409053315567 NRMSE-Sum: 0.09410747104196882
####################
CRPS: 0.26011952049351533 ND: 0.20650743754039277 NRMSE: 0.4151138381804453
CRPS-Sum: 0.06336520224541081 ND-Sum: 0.07484473851734613 NRMSE-Sum: 0.09508764091346383
####################
CRPS: 0.2601093999257114 ND: 0.2070753447062323 NRMSE: 0.4156746549973178
CRPS-Sum: 0.06409688970732619 ND-Sum: 0.07608443514622047 NRMSE-Sum: 0.09646779021487212
####################
CRPS: 0.26014712930773454 ND: 0.20705630582578188 NRMSE: 0.4156198601566928
CRPS-Sum: 0.06437298935610779 ND-Sum: 0.07621495594969595 NRMSE-Sum: 0.09620030446145965
####################
CRPS: 0.26027109664899484 ND: 0.20749231277297567 NRMSE: 0.4150930471801318
CRPS-Sum: 0.06324052632724432 ND-Sum: 0.07477818347451037 NRMSE-Sum: 0.09546541572941235
####################
CRPS: 0.26016797186718454 ND: 0.20662959076103335 NRMSE: 0.415340985320016
CRPS-Sum: 0.06412325424883036 ND-Sum: 0.07548524075831112 NRMSE-Sum: 0.09615459717651853
####################
CRPS: 0.26012589908863254 ND: 0.2064458184607882 NRMSE: 0.4149919184477151
CRPS-Sum: 0.0635993745573441 ND-Sum: 0.07649693017656356 NRMSE-Sum: 0.09510302895584935
####################
CRPS: 0.26020901522750806 ND: 0.20726866151194015 NRMSE: 0.4152586950977804
CRPS-Sum: 0.06365321078722078 ND-Sum: 0.07566832520599179 NRMSE-Sum: 0.09591685470429513
####################
CRPS: 0.2602277439680248 ND: 0.2069696877340713 NRMSE: 0.4155180711548777
CRPS-Sum: 0.06445920060385012 ND-Sum: 0.07682302215218939 NRMSE-Sum: 0.09691541679794838
####################
---------------End of Score_VP SDE!!!---------------
####################
<class 'score_sde.trainer.Trainer'>
md_type:{}, num_cells:{} vesde 128
------the method of ScoreGrad------is:vesde
CRPS: 0.10562849945602379 ND: 0.12612535290559584 NRMSE: 0.39687831730374046
CRPS-Sum: 0.04026621168092205 ND-Sum: 0.045703408786634356 NRMSE-Sum: 0.058602150074698
####################
CRPS: 0.10559970110983598 ND: 0.12607075631692474 NRMSE: 0.396524732801783
CRPS-Sum: 0.03994858483925993 ND-Sum: 0.04556137465950782 NRMSE-Sum: 0.05824216410830469
####################
CRPS: 0.10563661308538486 ND: 0.12613660622069725 NRMSE: 0.3965948146347445
CRPS-Sum: 0.0403951607849624 ND-Sum: 0.04601224068187979 NRMSE-Sum: 0.05882258022674767
####################
CRPS: 0.10561753362854286 ND: 0.12620421329135686 NRMSE: 0.3965220516234778
CRPS-Sum: 0.04022938414517865 ND-Sum: 0.04585492344153036 NRMSE-Sum: 0.05854503273385719
####################
CRPS: 0.10555488766171948 ND: 0.12612781608905627 NRMSE: 0.39647952141138276
CRPS-Sum: 0.040197161340561526 ND-Sum: 0.04570573513354915 NRMSE-Sum: 0.05864927328021731
####################
CRPS: 0.10554475528483076 ND: 0.12607622511223973 NRMSE: 0.3965879161123644
CRPS-Sum: 0.040051935064548454 ND-Sum: 0.045734251455310765 NRMSE-Sum: 0.058516007066785995
####################
CRPS: 0.10559282853266168 ND: 0.12613668131462394 NRMSE: 0.3964970791416158
CRPS-Sum: 0.04028174649377698 ND-Sum: 0.04569995505967803 NRMSE-Sum: 0.05858186095592334
####################
CRPS: 0.10559688846676331 ND: 0.12613328154874429 NRMSE: 0.39692492945265184
CRPS-Sum: 0.04001235796295822 ND-Sum: 0.04553677220144473 NRMSE-Sum: 0.05840052935157173
####################
CRPS: 0.10557811679930444 ND: 0.12612107599226524 NRMSE: 0.3962595877288983
CRPS-Sum: 0.04053670254231536 ND-Sum: 0.04618436424013901 NRMSE-Sum: 0.05908096204790365
####################
CRPS: 0.10559575354097123 ND: 0.12613047185732584 NRMSE: 0.3967996292863638
CRPS-Sum: 0.040034610345121534 ND-Sum: 0.045552295738366576 NRMSE-Sum: 0.05851000603541668
####################
---------------End of Score_VE SDE!!!---------------
####################
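
The averages quoted above can be recomputed from the CRPS-Sum lines of this log. The sketch below copies those values and also reports the spread across the ten evaluation rounds per model, presumably repeated sampling from the same checkpoint.

```python
import statistics

# CRPS-Sum per evaluation round, copied from the log above (10 rounds per model).
crps_sum = {
    "subvpsde": [0.38324099030932024, 0.37491109214372337, 0.37804421871533755,
                 0.3803011485287016, 0.37761255606613836, 0.37411721712527196,
                 0.3791473036535954, 0.3778284225351686, 0.37649053984627884,
                 0.3781326886016716],
    "vpsde": [0.0631504743008402, 0.06264833847738471, 0.06336520224541081,
              0.06409688970732619, 0.06437298935610779, 0.06324052632724432,
              0.06412325424883036, 0.0635993745573441, 0.06365321078722078,
              0.06445920060385012],
    "vesde": [0.04026621168092205, 0.03994858483925993, 0.0403951607849624,
              0.04022938414517865, 0.040197161340561526, 0.040051935064548454,
              0.04028174649377698, 0.04001235796295822, 0.04053670254231536,
              0.040034610345121534],
}

summary = {
    name: (statistics.mean(values), statistics.stdev(values))
    for name, values in crps_sum.items()
}
for name, (mean, stdev) in summary.items():
    print(f"{name}: mean={mean:.4f} stdev={stdev:.4f}")
```

The within-model standard deviation is far smaller than the gap between models, which suggests the subvpsde discrepancy is more likely to come from a configuration difference than from sampling noise.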

yantijin commented 1 year ago

Hi~, I have got this and will download the weights and check them. Please give me some time; I will post an update after checking!

yantijin commented 1 year ago

Hi~, we actually mentioned this limitation of the method in the paper: prediction performance cannot be judged from the loss value, so it is sometimes hard to choose the right number of training epochs. We did find that ScoreGrad's performance on traffic is more sensitive to the hyperparameters than on the other datasets, as we spent much more time on this dataset than on the others during the experiments.
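
Because the training loss does not track forecast quality, one workaround is to evaluate each saved checkpoint on a held-out window and keep the one with the lowest CRPS-Sum. The sketch below assumes forecast samples shaped (num_samples, time, num_series) as NumPy arrays and approximates CRPS-Sum with the weighted quantile-loss formulation (quantiles 0.05 to 0.95) used by GluonTS-style evaluators; `crps_sum`, `select_checkpoint` and the checkpoint names are hypothetical, not part of the ScoreGrad code.

```python
import numpy as np

QUANTILES = np.arange(1, 20) / 20.0  # 0.05, 0.10, ..., 0.95


def crps_sum(samples, target):
    """Approximate CRPS-Sum: weighted quantile loss of the series summed over dimensions.

    samples: (num_samples, time, num_series); target: (time, num_series).
    """
    agg_samples = samples.sum(axis=-1)  # (num_samples, time)
    agg_target = target.sum(axis=-1)    # (time,)
    denom = np.abs(agg_target).sum()
    losses = []
    for q in QUANTILES:
        pred_q = np.quantile(agg_samples, q, axis=0)  # (time,)
        indicator = (agg_target <= pred_q).astype(float)
        ql = 2.0 * np.abs((agg_target - pred_q) * (indicator - q)).sum()
        losses.append(ql / denom)
    return float(np.mean(losses))


def select_checkpoint(scores):
    """Return the checkpoint name with the lowest validation CRPS-Sum."""
    return min(scores, key=scores.get)


# Synthetic check: a forecast whose samples all equal the target scores ~0.
rng = np.random.default_rng(0)
target = rng.uniform(1.0, 2.0, size=(24, 5))
perfect = np.repeat(target[None], 100, axis=0)
noisy = perfect + rng.normal(0.0, 0.5, size=perfect.shape)

scores = {"epoch_20": crps_sum(noisy, target), "epoch_50": crps_sum(perfect, target)}
print(scores, select_checkpoint(scores))
```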

yantijin commented 1 year ago

As the server in our lab is busy with jobs right now, I need to check the weights on my PC, so the update may be slow. We first checked the provided weights of ScoreGrad (VESDE) on traffic, and the results match those reported in the paper. The config is:

import ml_collections
import torch

def get_default_configs():
    config = ml_collections.ConfigDict()

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.noise_removal = True # denoising at the last step.
    sampling.snr = 0.16 # signal-to-noise ratio in annealed Langevin sampling
    sampling.n_steps_each = 1
    sampling.probability_flow = False

    config.training = ml_collections.ConfigDict()

    config.modeling = modeling = ml_collections.ConfigDict()

    modeling.residual_layers = 8
    modeling.residual_channels = 8
    modeling.dilation_cycle_length = 2
    modeling.scaling = False # traffic doesn't need scaling

    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    config.input_size = 3856 # 963 * 3 + 963 + 4
    config.learning_rate = 1e-3
    config.num_layers = 2
    config.num_cells = 128
    config.num_parallel_samples = 100
    config.dropout_rate = 0.1
    config.conditioning_length = 100
    config.num_batches_per_epoch = 100

    return config
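
The comment `# 963 * 3 + 963 + 4` in this config encodes how `input_size = 3856` is derived. A small helper makes the arithmetic explicit; reading the three terms as lagged target values (three lags per series), a per-series embedding of size `target_dim`, and four time features is an assumption borrowed from TimeGrad-style models, not something stated in this thread.

```python
def rnn_input_size(target_dim, num_lags=3, embed_dim=None, num_time_features=4):
    """Hypothetical helper reproducing `963 * 3 + 963 + 4` from the traffic config.

    Assumed decomposition: target_dim * num_lags lagged values, an embedding of
    size embed_dim (defaults to target_dim), and num_time_features covariates.
    """
    if embed_dim is None:
        embed_dim = target_dim
    return target_dim * num_lags + embed_dim + num_time_features


print(rnn_input_size(963))  # traffic_nips has 963 series
```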

Note that noise removal is set to True. You can compare these common configs with yours. The specific config file smld_cont.py is:

from configs.traffic.traffic_config import get_default_configs

def get_traffic_smld_cont_config():
    config = get_default_configs()

    config.weight_decay = None
    config.reduce_mean = True
    config.likelihood_weighting = False
    config.batch_size = 32
    config.epochs = 30

    modeling = config.modeling
    modeling.num_scales = 200
    modeling.beta_min = 0.01
    modeling.beta_max = 15
    modeling.md_type = 'vesde'

    sampling = config.sampling
    sampling.method = 'pc'
    sampling.predictor = 'reverse_diffusion'
    sampling.corrector = 'langevin'

    training = config.training
    training.continuous = True
    training.seed = 123

    config.train = False
    config.save = False
    config.path = './model/traffic_smld_c.pkl'

    return config

The results obtained on my PC are:

CRPS: 0.10539966620295366 ND: 0.12592090852075508 NRMSE: 0.3954344121841787

CRPS-Sum: 0.038415752930585266 ND-Sum: 0.04364691766372566 NRMSE-Sum: 0.056334031434276774
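
The sampling settings in this config (`method = 'pc'`, `predictor = 'reverse_diffusion'`, `corrector = 'langevin'`, `snr = 0.16`, `n_steps_each = 1`, `noise_removal = True`) follow the predictor-corrector naming of the score_sde codebase by Song et al. To illustrate what those switches do, here is a toy 1-D re-implementation for a VE SDE whose data are N(0, s²), so the exact score is known in closed form. The `sigma_min`/`sigma_max` values and the analytic score are assumptions for illustration; this is not ScoreGrad's code. With noise removal, the last step returns the predictor mean instead of adding fresh noise.

```python
import numpy as np


def pc_sample_ve(num_samples, data_std=2.0, sigma_min=0.01, sigma_max=50.0,
                 num_scales=200, snr=0.16, n_steps_each=1, noise_removal=True, seed=0):
    """Toy predictor-corrector sampler for a VE SDE with Gaussian data N(0, data_std**2)."""
    rng = np.random.default_rng(seed)
    sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_scales))

    def score(x, sigma):
        # Exact score of the perturbed marginal N(0, data_std**2 + sigma**2).
        return -x / (data_std**2 + sigma**2)

    x = rng.normal(0.0, sigma_max, size=num_samples)  # sample from the prior
    x_mean = x
    for i in reversed(range(num_scales)):
        sigma = sigmas[i]
        # Corrector: Langevin steps whose size is set by the target SNR.
        for _ in range(n_steps_each):
            grad = score(x, sigma)
            noise = rng.standard_normal(num_samples)
            step = 2.0 * (snr * np.abs(noise).mean() / np.abs(grad).mean()) ** 2
            x_mean = x + step * grad
            x = x_mean + np.sqrt(2.0 * step) * noise
        # Predictor: reverse-diffusion step from sigmas[i] to sigmas[i - 1] (0 at the end).
        prev_sigma = sigmas[i - 1] if i > 0 else 0.0
        g2 = sigma**2 - prev_sigma**2
        x_mean = x + g2 * score(x, sigma)
        x = x_mean + np.sqrt(g2) * rng.standard_normal(num_samples)
    # noise_removal: return the last predictor mean, without its injected noise.
    return x_mean if noise_removal else x


samples = pc_sample_ve(20000)
print(samples.std())  # close to data_std = 2.0
```

In this toy setting the final denoising step changes little because `sigma_min` is tiny; the Langevin corrector also adds a small upward bias to the spread.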

yantijin commented 1 year ago

The provided weights of ScoreGrad (subVPSDE) have now been verified as well. The config file subvpsde.py is:

from configs.traffic.traffic_config import get_default_configs

def get_traffic_subvpsde_config():
    config = get_default_configs()

    config.weight_decay = None
    config.reduce_mean = True
    config.likelihood_weighting = False
    config.batch_size = 32
    config.epochs = 50

    modeling = config.modeling
    modeling.num_scales = 200
    modeling.beta_min = 0.01
    modeling.beta_max = 15 
    modeling.md_type = 'subvpsde'

    sampling = config.sampling
    sampling.method = 'pc'
    sampling.predictor = 'euler_maruyama'
    sampling.corrector = 'none'

    training = config.training
    training.continuous = True
    training.seed = 3407

    config.train = False
    config.save = False
    config.path = './model/traffic_subvpsde.pkl'

    return config
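
Comparing the estimator arguments from the earlier subvpsde training attempt in this thread with this config shows where they differ. The sketch below is a plain dictionary diff; pairing `diff_steps` with `num_scales` and `beta_end` with `beta_max` is an assumption about how the estimator arguments map onto the config fields. Sampler settings (`noise_removal`, predictor, corrector) and the seed are not visible in the estimator call, so they are left out.

```python
# Settings from the earlier subvpsde run (estimator call in this thread).
user_run = {
    "batch_size": 64,
    "epochs": 20,
    "num_scales": 100,   # assumed to correspond to diff_steps=100
    "beta_min": 0.01,
    "beta_max": 10,      # assumed to correspond to beta_end=10
    "num_cells": 128,
    "learning_rate": 1e-3,
}

# Settings from subvpsde.py and the shared traffic config above.
author_config = {
    "batch_size": 32,
    "epochs": 50,
    "num_scales": 200,
    "beta_min": 0.01,
    "beta_max": 15,
    "num_cells": 128,
    "learning_rate": 1e-3,
}


def diff_configs(a, b):
    """Return {key: (a_value, b_value)} for every key whose values differ."""
    return {k: (a.get(k), b.get(k)) for k in sorted(set(a) | set(b)) if a.get(k) != b.get(k)}


for key, (mine, theirs) in diff_configs(user_run, author_config).items():
    print(f"{key}: {mine} -> {theirs}")
```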

The results obtained on my PC are:

CRPS: 0.10000249099794856 ND: 0.11853302451532367 NRMSE: 0.40078522603670796

CRPS-Sum: 0.04169813552265052 ND-Sum: 0.04889053467565276 NRMSE-Sum: 0.06718607722245465
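
The subVPSDE config uses `predictor = 'euler_maruyama'` with no corrector and beta values from 0.01 to 15. As a toy 1-D illustration (again Gaussian data N(0, s²) with an exact score, not ScoreGrad's learned model), the sketch below integrates the reverse-time sub-VP SDE with Euler-Maruyama, using the sub-VP drift, diffusion and marginals as defined by Song et al.; the `eps` cutoff and sample count are assumptions.

```python
import numpy as np


def em_sample_subvp(num_samples, data_std=2.0, beta_min=0.01, beta_max=15.0,
                    num_scales=200, eps=1e-3, noise_removal=True, seed=0):
    """Toy Euler-Maruyama sampler for the reverse-time sub-VP SDE (no corrector)."""
    rng = np.random.default_rng(seed)

    def beta(t):
        return beta_min + t * (beta_max - beta_min)

    def int_beta(t):
        # Integral of beta(s) from 0 to t.
        return beta_min * t + 0.5 * t**2 * (beta_max - beta_min)

    def score(x, t):
        # For data N(0, data_std**2) the sub-VP marginal is N(0, m2 * data_std**2 + std**2),
        # with m2 = exp(-int_beta) and std = 1 - exp(-int_beta).
        m2 = np.exp(-int_beta(t))
        std = 1.0 - np.exp(-int_beta(t))
        return -x / (m2 * data_std**2 + std**2)

    x = rng.standard_normal(num_samples)  # sub-VP prior is N(0, 1)
    x_mean = x
    dt = -1.0 / num_scales
    for t in np.linspace(1.0, eps, num_scales):
        drift = -0.5 * beta(t) * x                         # forward drift f(x, t)
        g2 = beta(t) * (1.0 - np.exp(-2.0 * int_beta(t)))  # squared diffusion g(t)**2
        rev_drift = drift - g2 * score(x, t)               # reverse-time drift
        x_mean = x + rev_drift * dt
        x = x_mean + np.sqrt(g2 * -dt) * rng.standard_normal(num_samples)
    return x_mean if noise_removal else x


samples = em_sample_subvp(20000)
print(samples.std())  # close to data_std = 2.0
```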