Fix incorrect input routing for models #3186

Closed. shchur closed this pull request 4 weeks ago.

shchur commented 4 weeks ago

Fixes #3185

Description of changes:

There is currently a bug where the forecast generator may route model inputs incorrectly. As a result, past_feat_dynamic_real and past_feat_dynamic_cat are effectively ignored by the TFT model.

The following snippet reproduces the issue:


```python
from unittest import mock
import numpy as np
import pandas as pd
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
data = [
    {
        "target": np.arange(N),
        "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"),
        "start": pd.Period("2020-01-01", freq=freq),
    }
]

predictor = TemporalFusionTransformerEstimator(
    prediction_length=1, freq=freq, past_dynamic_dims=[1], trainer_kwargs={"max_epochs": 1}
).train(data)

# Patch _preprocess so we can inspect the keyword arguments the model receives.
with mock.patch(
    "gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess"
) as mock_fwd:
    try:
        fcst = list(predictor.predict(data))
    except Exception:
        # The mock returns a MagicMock, so prediction can fail after the call is recorded.
        pass
    call_kwargs = mock_fwd.call_args[1]

# Output when inspecting call_kwargs:
# tensor([[[0.8073]]])
# None
```

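The bug occurs because the model inputs are passed as positional arguments instead of keyword arguments. Whenever the order of the input names differs from the order of the parameters in the model's method signature, tensors are bound to the wrong parameters.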
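A minimal, self-contained sketch of why positional passing misroutes inputs while keyword passing does not. The function `forward`, the helpers `route_positionally` and `route_by_keyword`, and the batch contents are hypothetical. They only mimic the shape of the problem and are not the actual GluonTS forecast generator or TFT module code.

```python
# Hypothetical, simplified stand-in for a model method. The parameter names
# mirror GluonTS field names, but this is NOT the actual TFT signature.
def forward(feat_static_real=None, feat_dynamic_real=None, past_feat_dynamic_real=None):
    return {
        "feat_static_real": feat_static_real,
        "feat_dynamic_real": feat_dynamic_real,
        "past_feat_dynamic_real": past_feat_dynamic_real,
    }


# Hypothetical batch containing only the fields present in the data.
input_names = ["past_feat_dynamic_real"]
batch = {"past_feat_dynamic_real": [[0.8073]]}


def route_positionally(fn, names, batch):
    # Buggy pattern: values are bound by position, so the first value lands
    # in the first parameter of fn, whatever its name is.
    return fn(*[batch[name] for name in names])


def route_by_keyword(fn, names, batch):
    # Fixed pattern: each value is bound to the parameter with the same name.
    return fn(**{name: batch[name] for name in names})


buggy = route_positionally(forward, input_names, batch)
fixed = route_by_keyword(forward, input_names, batch)

print(buggy["past_feat_dynamic_real"])  # None
print(fixed["past_feat_dynamic_real"])  # [[0.8073]]
```

A side benefit of keyword routing is that an input name with no matching parameter raises a `TypeError` instead of silently shifting every later argument by one position.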

