awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Fix incorrect input routing for models #3186

Closed by shchur 4 weeks ago

shchur commented 4 weeks ago

Fixes #3185

Description of changes:

There is currently a bug where the forecast generator may route model inputs incorrectly. As a result, the TFT model effectively ignores past_feat_dynamic_real and past_feat_dynamic_cat.

MWE:

from unittest import mock

import numpy as np
import pandas as pd
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
# One series with a single past-only dynamic real feature
data = [
    {
        "target": np.arange(N),
        "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"),
        "start": pd.Period("2020-01-01", freq=freq),
    }
]

predictor = TemporalFusionTransformerEstimator(
    prediction_length=1,
    freq=freq,
    past_dynamic_dims=[1],
    trainer_kwargs={"max_epochs": 1},
).train(data)

# Replace _preprocess with a mock to record the arguments the model passes to it
with mock.patch(
    "gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess"
) as mock_preprocess:
    try:
        fcst = list(predictor.predict(data))
    except Exception:
        # The mocked _preprocess returns a MagicMock, so prediction may fail
        # downstream; only the recorded call arguments matter here
        pass
    call_kwargs = mock_preprocess.call_args.kwargs

call_kwargs["feat_dynamic_cat"]
# tensor([[[0.8073]]])
call_kwargs["past_feat_dynamic_real"]
# None
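The MWE relies on `unittest.mock.patch` to intercept `_preprocess` and read back the keyword arguments it received. As a self-contained illustration of that inspection technique, here is a minimal sketch on a hypothetical `ToyModel` class (not part of GluonTS); the class, its method signatures and the `captured_kwargs` helper are all assumptions made for the example.

```python
from unittest import mock


class ToyModel:
    """Hypothetical stand-in for a model whose forward delegates to _preprocess."""

    def _preprocess(self, past_target, past_feat_dynamic_real=None, feat_dynamic_cat=None):
        return past_target

    def forward(self, past_target, past_feat_dynamic_real=None, feat_dynamic_cat=None):
        # forward hands everything on by keyword, so the mock records parameter names
        return self._preprocess(
            past_target=past_target,
            past_feat_dynamic_real=past_feat_dynamic_real,
            feat_dynamic_cat=feat_dynamic_cat,
        )


def captured_kwargs(*args):
    """Call ToyModel.forward and return the keyword arguments _preprocess received."""
    # patch.object swaps the class attribute for a MagicMock only inside the block
    with mock.patch.object(ToyModel, "_preprocess") as mock_preprocess:
        ToyModel().forward(*args)
    # call_args.kwargs is the dict of keyword arguments from the last call
    return mock_preprocess.call_args.kwargs
```

For example, `captured_kwargs([1.0], [0.5])` shows that `[0.5]` was bound to `past_feat_dynamic_real` while `feat_dynamic_cat` stayed `None`, which is exactly the kind of check the MWE performs on the real model.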

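The bug occurs because model inputs are passed as positional arguments instead of keyword arguments, so whenever the order of the input names differs from the order of the parameters in the model's signature, tensors are bound to the wrong parameters.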
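A minimal pure-Python sketch of the failure mode and of the keyword-argument fix follows. It is not the GluonTS forecast generator: `toy_forward`, `call_positionally`, `call_by_keyword` and the `input_names` order are hypothetical, chosen so that the symptom mirrors the MWE (the past real feature lands in `feat_dynamic_cat`, and `past_feat_dynamic_real` stays `None`).

```python
def toy_forward(past_target, feat_dynamic_real=None, feat_dynamic_cat=None,
                past_feat_dynamic_real=None):
    """Hypothetical model entry point; reports what each parameter received."""
    return {
        "past_target": past_target,
        "feat_dynamic_real": feat_dynamic_real,
        "feat_dynamic_cat": feat_dynamic_cat,
        "past_feat_dynamic_real": past_feat_dynamic_real,
    }


# Hypothetical input order used by a predictor; it differs from toy_forward's signature
input_names = ["past_target", "feat_dynamic_real", "past_feat_dynamic_real"]
batch = {
    "past_target": [1.0, 2.0],
    "feat_dynamic_real": [0.1],
    "past_feat_dynamic_real": [0.8],
}


def call_positionally(fn, batch, input_names):
    # Buggy: values are bound by position, so their names are ignored
    return fn(*[batch[name] for name in input_names])


def call_by_keyword(fn, batch, input_names):
    # Fixed: each value is bound to the parameter with the same name
    return fn(**{name: batch[name] for name in input_names})
```

With `call_positionally`, `[0.8]` ends up in `feat_dynamic_cat` and `past_feat_dynamic_real` is `None`; with `call_by_keyword`, it reaches `past_feat_dynamic_real`. Passing by keyword has the added benefit that a name missing from the signature raises a `TypeError` instead of being silently misrouted.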

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to speed up our release process: BREAKING, new feature, bug fix, other change, dev setup