awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Error calling make_evaluation_predictions with TFT using past_feat_dynamic_real after update to 0.15.0 #3185

Closed wbinek closed 5 months ago

wbinek commented 6 months ago

Description

The issue is caused by changes to QuantileForecastGenerator. Some of the inputs to the TFT forward method are optional. The new QuantileForecastGenerator changed the way inputs are passed to the model, which breaks cases where some, but not all, of the optional features are provided. In my case the problem occurred when the dataset contained past_feat_dynamic_real but not feat_dynamic_cat.

The new implementation of QuantileForecastGenerator calls:

for batch in inference_data_loader:
    inputs = select(input_names, batch, ignore_missing=True)
    (outputs,), loc, scale = prediction_net(*inputs.values())

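This causes problems if some of the input_names are not present in the dataset. Because select is called with ignore_missing=True, the missing names are dropped, and unpacking the remaining values positionally assigns them to the wrong function parameters. The problem can be fixed by changing the last line to: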

(outputs,), loc, scale = prediction_net(**inputs)

In that case the names are passed along with the values, so the inputs and parameters don't get mixed up.
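
The mismatch can be illustrated without GluonTS at all. The following is a minimal sketch under stated assumptions: `select` here is a hypothetical stand-in that only mimics the filtering behaviour described above, and `forward` is a toy function with optional parameters, not the actual TFT signature.

```python
def select(keys, source, ignore_missing=False):
    """Hypothetical stand-in: pick `keys` from `source`, skipping absent ones if allowed."""
    result = {}
    for key in keys:
        if key in source:
            result[key] = source[key]
        elif not ignore_missing:
            raise KeyError(key)
    return result


def forward(past_target, feat_dynamic_cat=None, past_feat_dynamic_real=None):
    """Toy model (not the real TFT forward): report which parameter each value landed in."""
    return {
        "past_target": past_target,
        "feat_dynamic_cat": feat_dynamic_cat,
        "past_feat_dynamic_real": past_feat_dynamic_real,
    }


input_names = ["past_target", "feat_dynamic_cat", "past_feat_dynamic_real"]
# The batch has past_feat_dynamic_real but no feat_dynamic_cat.
batch = {"past_target": "target-values", "past_feat_dynamic_real": "real-values"}

inputs = select(input_names, batch, ignore_missing=True)

# Positional unpacking: the second value slides into the feat_dynamic_cat slot.
positional = forward(*inputs.values())

# Keyword unpacking: each value reaches the parameter with the same name.
by_keyword = forward(**inputs)
```

With positional unpacking the real-valued feature ends up in the categorical slot and past_feat_dynamic_real stays None, which is consistent with a downstream layer receiving None instead of a tensor.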

To Reproduce


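from gluonts.dataset.pandas import PandasDataset

# Using a dataset like this will cause the described problem:
# it provides past_feat_dynamic_real but no feat_dynamic_cat.
# `df` is a long-format pandas DataFrame containing the columns named below.
static_feature_cols = ['sf1', 'sf2']
dynamic_feature_cols = ['df1', 'df2']

test = PandasDataset.from_long_dataframe(
    df,
    target="target",
    item_id="series_idx",
    timestamp="date",
    static_feature_columns=static_feature_cols,
    past_feat_dynamic_real=dynamic_feature_cols,
    freq="1D",
)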

Error message or code output

TypeError: linear(): argument 'input' (position 1) must be Tensor, not NoneType

Environment

(Add as much information about your environment as possible, e.g. dependencies versions.)