awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Median prediction is plotted one index shift to the left #2205

Open zhichenggeng opened 2 years ago

zhichenggeng commented 2 years ago

Description

The median prediction is plotted one index too far to the left, while the 50% and 90% prediction intervals are plotted correctly. A possible cause is that the data is resampled to weekly frequency, which may affect how the index is handled.
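To make the suspicion about weekly resampling concrete, here is a minimal pandas-only sketch of the index labels that `resample('W', label='left')` produces and how they map to weekly periods. The toy series of ones is an assumption made for illustration, not the data from the reproduction, and this does not by itself establish the cause of the shifted median.

```python
import pandas as pd

# Toy daily series (illustrative assumption): 14 days starting on a Friday.
daily = pd.Series(1.0, index=pd.date_range("2021-01-01", periods=14, freq="D"))

# "W" is anchored on Sunday; label="left" stamps each bin with the Sunday
# that precedes the days it contains.
weekly = daily.resample("W", label="left").sum()

# Converting those Sunday stamps to weekly periods yields the week that
# *ends* on that Sunday, i.e. the week before the summed days.
periods = weekly.index.to_period("W")

print(weekly)
print(periods)
```

If the labels should instead denote the week that contains the summed days, dropping `label='left'` is one option to try.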

To Reproduce

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def generate_single_ts(date_range, item_id=None) -> pd.DataFrame:
    """create sum of `n_f` sin/cos curves with random scale and phase."""
    n_f = 2
    period = np.array([24 / (i + 1) for i in range(n_f)]).reshape(1, n_f)
    scale = np.random.normal(1, 0.3, size=(1, n_f))
    phase = 2 * np.pi * np.random.uniform(size=(1, n_f))
    periodic_f = lambda x: scale * np.sin(np.pi * x / period + phase)

    t = np.arange(0, len(date_range)).reshape(-1, 1)
    target = periodic_f(t).sum(axis=1) + np.random.normal(0, 0.1, size=len(t))
    ts = pd.DataFrame({"target": target}, index=date_range)
    if item_id is not None:
        ts["item_id"] = item_id
    return ts

prediction_length, freq = 2, "1D"
T = 365 * prediction_length
date_range = pd.date_range("2021-01-01", periods=T, freq=freq)
ts = generate_single_ts(date_range)

ts = ts.resample('W', label='left').sum()

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx import Trainer
from gluonts.evaluation import make_evaluation_predictions, Evaluator

prediction_length, freq = 4, "W"
estimator = DeepAREstimator(
    freq=freq, prediction_length=prediction_length, trainer=Trainer(epochs=1)
)

from gluonts.dataset.pandas import PandasDataset

train = PandasDataset(ts[:-prediction_length], target="target", freq=freq)
test = PandasDataset(ts, target="target", freq=freq)

predictor = estimator.train(train)
forecast_it, ts_it = make_evaluation_predictions(dataset=test, predictor=predictor)

forecasts = list(forecast_it)
tss = list(ts_it)
forecast_entry = forecasts[0]
ts_entry = tss[0]

fig, ax = plt.subplots(1, 1, figsize=(20, 8))
ts_entry[-prediction_length * 4:].plot(ax=ax)
forecast_entry.plot(prediction_intervals=(50, 90), color="g")

Error message or code output

Environment

jaheba commented 2 years ago

Thanks for the reproducible example!

This seems to be an issue when mixing Period-indexed and Timestamp-indexed data in the same plot. If I change the plotting of the true values, it works:

ts_entry[-prediction_length * 4:].to_timestamp().plot(ax=ax)

I don't know (yet) why this happens -- but at least there is a workaround.
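For reference, a minimal sketch of what the workaround does to the index, using pandas only. The series here is a hypothetical stand-in for `ts_entry`, not the object returned by `make_evaluation_predictions`.

```python
import pandas as pd

# Hypothetical stand-in for `ts_entry`: a series indexed by weekly periods.
period_index = pd.period_range("2021-01-04", periods=8, freq="W")
ts_entry_like = pd.Series(range(8), index=period_index, dtype=float)

# to_timestamp() replaces each period with a timestamp (by default the
# start of the period), so the result has a DatetimeIndex.
ts_as_timestamps = ts_entry_like.to_timestamp()

print(type(ts_entry_like.index).__name__)     # PeriodIndex
print(type(ts_as_timestamps.index).__name__)  # DatetimeIndex
```

After the conversion, the true values are plotted against timestamps rather than periods.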

I was thinking about adding some more plot utilities to a) make generating these kinds of plots easier and b) avoid issues like this one.

zhichenggeng commented 2 years ago

It works! Thanks for your help.

Looking forward to more generalized plotting tools.

lostella commented 2 years ago

@jaheba as a fix, would it make sense to bake .to_timestamp into whatever utils (hopefully it’s a single place) we use to turn a DataEntry into pandas objects?
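A rough sketch of what such a conversion step could look like. The helper name `ensure_timestamp_index` is hypothetical and is not an existing GluonTS function; it only illustrates the idea with plain pandas.

```python
import pandas as pd


def ensure_timestamp_index(data):
    """Return `data` with a DatetimeIndex if it is period-indexed.

    Hypothetical helper for illustration; not part of GluonTS.
    """
    if isinstance(data.index, pd.PeriodIndex):
        # Default conversion uses the start of each period.
        return data.to_timestamp()
    # Already timestamp-indexed (or something else): leave untouched.
    return data


weekly_periods = pd.Series(
    [1.0, 2.0, 3.0],
    index=pd.period_range("2021-01-04", periods=3, freq="W"),
)
converted = ensure_timestamp_index(weekly_periods)
unchanged = ensure_timestamp_index(converted)
```

Calling the helper on data that is already timestamp-indexed returns it as-is, so it would be safe to apply unconditionally before plotting.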

lostella commented 2 years ago

What's interesting is that inverting the order of plotting also solves the issue:

forecast_entry.plot(prediction_intervals=(50, 90), color="g")
ts_entry[-prediction_length * 4:].plot(ax=ax)
