Closed: parora07234 closed this issue 2 years ago.
Try this sample script.
```python
import matplotlib.pyplot as plt

from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation import make_evaluation_predictions
from gluonts.model.seq2seq import MQCNNEstimator
from gluonts.mx import Trainer

# Hourly electricity dataset that ships with GluonTS
dataset = get_dataset("electricity_nips", regenerate=False)

estimator = MQCNNEstimator(
    prediction_length=dataset.metadata.prediction_length,
    context_length=4 * 24,  # four days of hourly history
    freq=dataset.metadata.freq,
    trainer=Trainer(
        ctx="cpu",
        epochs=3,
        learning_rate=1e-3,
        num_batches_per_epoch=100,
    ),
    channels_seq=[30, 30, 30],
    dilation_seq=[1, 3, 5],
    kernel_size_seq=[7, 3, 3],
    # MQCNN predicts these quantile levels directly instead of sample paths
    quantiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
)
predictor = estimator.train(dataset.train)

forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,  # test dataset
    predictor=predictor,   # trained predictor
    num_samples=100,       # number of sample paths we want for evaluation
)

forecast_entry = next(iter(forecast_it))
ts_entry = next(iter(ts_it))

forecast_entry.plot()     # plot the forecast quantiles
plt.plot(ts_entry[-96:])  # overlay the last 96 observed points
plt.show()
```
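In some GluonTS versions a quantile forecast's `plot()` draws each quantile as a separate line. To get shaded prediction intervals like the ones from DeepAR/DeepState, the symmetric levels can be paired up: 0.1 with 0.9 gives an 80% central interval, 0.2 with 0.8 a 60% one, and so on. The helper below is a minimal, library-free sketch of that pairing; `central_intervals` is a hypothetical name, and it assumes you have already collected the quantile arrays into a dict, for example with `{q: forecast_entry.quantile(q) for q in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]}` (check the exact accessor in your GluonTS version).

```python
from typing import Dict, List, Sequence, Tuple


def central_intervals(
    quantiles: Dict[float, Sequence[float]],
) -> List[Tuple[float, Sequence[float], Sequence[float]]]:
    """Pair symmetric quantile levels (q, 1 - q) into central prediction intervals.

    Assumption (hypothetical helper): keys are quantile levels as floats and
    values are equal-length sequences of forecast values.
    Returns (coverage, lower, upper) tuples, widest interval first.
    """
    levels = sorted(quantiles)
    intervals = []
    for q in levels:
        if q >= 0.5:
            break  # the median and upper levels are only used as partners
        # Find the mirrored upper level, tolerating float rounding (e.g. 1 - 0.3).
        match = next((p for p in levels if abs(p - (1 - q)) < 1e-9), None)
        if match is not None:
            coverage = round(match - q, 6)  # e.g. 0.9 - 0.1 -> 0.8
            intervals.append((coverage, quantiles[q], quantiles[match]))
    return intervals


if __name__ == "__main__":
    # Made-up two-step forecast, for illustration only.
    toy = {0.1: [1.0, 2.0], 0.5: [1.5, 2.5], 0.9: [2.0, 3.0]}
    for coverage, lower, upper in central_intervals(toy):
        print(f"{coverage:.0%} interval: {list(lower)} to {list(upper)}")
```

With matplotlib, each tuple can then be shaded, e.g. `plt.fill_between(forecast_entry.index, lower, upper, alpha=0.3)`, assuming `forecast_entry.index` holds the forecast timestamps. Drawing the widest interval first keeps the narrower bands visible on top, and the 0.5 quantile can be plotted as a line for the median.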
Hope this helps.
Hey @parora07234, does @dai-ichiro's suggestion above help? If so, this issue can be closed.
I have plotted probabilistic forecasts (prediction intervals) using DeepAR and DeepState, but now I want to use MQCNN/MQRNN and cannot figure out how to plot the quantile forecast.
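The difference between the two cases can be sketched briefly: DeepAR and DeepState return sample paths, and a prediction interval is read off empirical quantiles of those samples at each time step, whereas MQCNN/MQRNN configured with `quantiles` output the quantile values directly, so there is nothing to aggregate. The function below is a hypothetical, simplified sort-and-index rule for the sample-based case only; it is not GluonTS's own implementation, which may differ in details.

```python
from typing import List, Sequence


def empirical_quantile(sample_paths: Sequence[Sequence[float]], q: float) -> List[float]:
    """Per-time-step empirical quantile across sample paths (hypothetical helper).

    Sorts the samples at each step and takes the one at position
    round((num_samples - 1) * q).
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    num_samples = len(sample_paths)
    idx = int(round((num_samples - 1) * q))
    horizon = len(sample_paths[0])
    return [sorted(path[t] for path in sample_paths)[idx] for t in range(horizon)]


if __name__ == "__main__":
    # Three made-up sample paths over a two-step horizon.
    paths = [[1.0, 4.0], [2.0, 6.0], [3.0, 5.0]]
    print(empirical_quantile(paths, 0.1), empirical_quantile(paths, 0.9))
```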
Error message or code output
After running the forecast, I get the error below:
Environment