sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Interpreting DeepAR.predict() #1521

Open jambudipa opened 8 months ago

jambudipa commented 8 months ago

I have built a DeepAR model based on the sample, but I am not sure how to interpret the output of the predict() method:

```python
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet

# append max_prediction_length future rows after the last observed time index
max_ = full_data.time_idx.max()
future_data = pd.DataFrame({
    'time_idx': range(max_ + 1, max_ + 1 + self.max_prediction_length),
    'price': 0,  # placeholder: the target is unknown for future steps
    'volume': 0,
    'group_id': asset_pair}).astype(pair_data.dtypes)

# concat historical and future data (ignore_index avoids duplicate row labels)
full_data = pd.concat([full_data, future_data], ignore_index=True)

# create new TimeSeriesDataSet with full data
full_dataset = TimeSeriesDataSet(
    full_data,
    time_idx="time_idx",
    target="price",
    group_ids=["group_id"],
    time_varying_unknown_reals=["price"],
    min_prediction_idx=training_cutoff + 1,
    min_encoder_length=self.max_encoder_length,
    max_encoder_length=self.max_encoder_length,
    max_prediction_length=self.max_prediction_length)

# a single batch containing every sample in the dataset
full_dataloader = full_dataset.to_dataloader(
    train=False,
    batch_size=len(full_dataset),
    num_workers=0
)

future_predictions = best_model.predict(full_dataloader)
future_predictions_df = pd.DataFrame(future_predictions.cpu().numpy())
```

The data consists of (time, price) pairs, yet future_predictions is a 36x37 tensor, where 36 is the forecasting length.

I am struggling to understand what this represents. With an ordinary model I would expect something like 36x1. I know that DeepAR predicts a range of values rather than a single point, but is it the case that each time index just happens to include as many predictions as there are time indexes?
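One possible reading of the shape, offered as a sketch rather than a confirmed answer (the thread does not resolve it): with a train=False dataloader, predict() presumably returns one point forecast per horizon step for each sample in the dataset, and each sample is one full-length decoder window. Assuming training_cutoff = max_ - max_prediction_length (a common pattern; the snippet does not show how training_cutoff is set), min_prediction_length left at its default of max_prediction_length, enough history before each window for the encoder, and a single group, the number of windows can be counted with plain arithmetic. The helper count_prediction_windows and the value max_ = 1000 are hypothetical.

```python
def count_prediction_windows(min_prediction_idx: int, last_time_idx: int,
                             max_prediction_length: int) -> int:
    """Count full-length decoder windows whose first predicted step is at or
    after min_prediction_idx and whose last predicted step is <= last_time_idx."""
    # A window starting at s covers s .. s + max_prediction_length - 1,
    # so the latest allowed start is last_time_idx - max_prediction_length + 1.
    last_start = last_time_idx - max_prediction_length + 1
    return max(0, last_start - min_prediction_idx + 1)


# Hypothetical numbers: a 36-step horizon and an ASSUMED
# training_cutoff = max_ - max_prediction_length.
max_prediction_length = 36
max_ = 1000                                      # last historical time_idx (illustrative)
training_cutoff = max_ - max_prediction_length
min_prediction_idx = training_cutoff + 1
last_time_idx = max_ + max_prediction_length     # after appending the future rows

n_windows = count_prediction_windows(min_prediction_idx, last_time_idx,
                                     max_prediction_length)
print(n_windows, max_prediction_length)  # 37 36
```

Under these assumptions the dataset would hold 37 windows of 36 steps each, so a (37, 36) output, one row per window, would be expected; the reported 36x37 may be the same two numbers read in the other order, which future_predictions.shape would confirm.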

ivanightingale commented 7 months ago

I couldn't reproduce your problem. I tried a small dataframe, the DeepAR model from the sample, and your settings, and got a sequence of scalar predictions. Would you mind testing with a small dataframe and, if the problem persists, posting the entire reproducible code snippet here?

Hint: adding the word "Python" after your "```" makes your code snippet more legible.
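For a reproducible report along the lines suggested above, a small deterministic dataset can be generated without any real market data. This is a minimal sketch that assumes only the column names from the question (time_idx, price, volume, group_id); the helper make_toy_rows, the group label "PAIR_A", and the random-walk shape of the series are hypothetical choices for illustration.

```python
import math
import random


def make_toy_rows(n_steps: int, group_id: str = "PAIR_A", seed: int = 0) -> list[dict]:
    """Generate a small price/volume series with the columns used in the question."""
    rng = random.Random(seed)  # seeded so every run produces identical data
    rows = []
    price = 100.0
    for t in range(n_steps):
        # random walk plus a slow sine drift, clamped to stay >= 1.0
        price = max(1.0, price + rng.gauss(0.0, 1.0) + 0.5 * math.sin(t / 10))
        rows.append({
            "time_idx": t,
            "price": round(price, 4),
            "volume": rng.randint(1, 1000),
            "group_id": group_id,
        })
    return rows


rows = make_toy_rows(200)
print(len(rows), rows[0]["time_idx"], rows[-1]["time_idx"])  # 200 0 199
```

The rows can then be wrapped with pd.DataFrame(rows) and passed through the same TimeSeriesDataSet and predict() code as in the question; because the seed is fixed, anyone running the snippet sees the same data and output shape.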