There is currently a bug where the forecast generator may route model inputs incorrectly. As a result, `past_feat_dynamic_real` and `past_feat_dynamic_cat` are effectively ignored by the TFT model.
MWE:
```python
from unittest import mock

import numpy as np
import pandas as pd
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
data = [
    {"target": np.arange(N), "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"), "start": pd.Period("2020-01-01", freq=freq)}
]
predictor = TemporalFusionTransformerEstimator(
    prediction_length=1, freq=freq, past_dynamic_dims=[1], trainer_kwargs={"max_epochs": 1}
).train(data)

# Record the arguments that reach the module's preprocessing step during prediction.
with mock.patch("gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess") as mock_fwd:
    try:
        fcst = list(predictor.predict(data))
    except Exception:
        # The mocked _preprocess returns a MagicMock instead of real tensors,
        # so prediction is expected to fail after the call has been recorded.
        pass

call_kwargs = mock_fwd.call_args[1]
call_kwargs["feat_dynamic_cat"]
# tensor([[[0.8073]]])
call_kwargs["past_feat_dynamic_real"]
# None
```
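The bug occurs because the model inputs are passed as positional arguments instead of keyword arguments. Whenever the order of the input names differs from the order of the model's parameters, tensors are bound to the wrong parameters: in the MWE above, a real-valued tensor arrives as `feat_dynamic_cat`, while `past_feat_dynamic_real` is `None`.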
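The routing problem can be reproduced without GluonTS. The sketch below uses a hypothetical `forward` function and a hypothetical `input_names` list (neither is the actual TFT module or forecast generator code) to show that positional unpacking binds values by order, while keyword unpacking binds them by name.

```python
def forward(past_target, feat_dynamic_cat=None, past_feat_dynamic_real=None):
    """Hypothetical model entry point; returns what each parameter received."""
    return {
        "past_target": past_target,
        "feat_dynamic_cat": feat_dynamic_cat,
        "past_feat_dynamic_real": past_feat_dynamic_real,
    }


# Hypothetical prediction inputs, listed in an order that does not match forward().
input_names = ["past_target", "past_feat_dynamic_real"]
batch = {"past_target": [1.0, 2.0, 3.0], "past_feat_dynamic_real": [0.8073]}

# Positional call: binding follows list order, so the real-valued feature lands in feat_dynamic_cat.
positional = forward(*[batch[name] for name in input_names])

# Keyword call: binding follows names, so each value reaches the intended parameter.
keyword = forward(**{name: batch[name] for name in input_names})

print(positional["feat_dynamic_cat"], positional["past_feat_dynamic_real"])  # [0.8073] None
print(keyword["feat_dynamic_cat"], keyword["past_feat_dynamic_real"])  # None [0.8073]
```

Binding by name makes the call independent of how the input names happen to be ordered.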
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup
Fixes #3185