Nixtla / mlforecast

Scalable machine 🤖 learning for time series forecasting.
https://nixtlaverse.nixtla.io/mlforecast
Apache License 2.0

Not enough models trained in cross_validation with fitted=True and horizon > 9 #331

Closed: adriaanvh1 closed this issue 3 months ago

adriaanvh1 commented 3 months ago

What happened + What you expected to happen

Bug: When cross-validating with fitted=True and max_horizon greater than 9, an IndexError is raised.

Expected behaviour: Cross-validation runs without error.

Useful information: Traceback

{
    "name": "IndexError",
    "message": "index 10 is out of bounds for axis 1 with size 10",
    "stack": "---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[3], line 20
     18 # Get in-sample cross-validation predictions
     19 h = 14
---> 20 _cv_results = fcst.cross_validation(
     21     df=target,
     22     n_windows=4,
     23     h=h,
     24     max_horizon=h,
     25     fitted=True,
     26 )

File /opt/composor-env/lib/python3.10/site-packages/mlforecast/forecast.py:848, in MLForecast.cross_validation(self, df, n_windows, h, id_col, time_col, target_col, step_size, static_features, dropna, keep_last_n, refit, max_horizon, before_predict_callback, after_predict_callback, prediction_intervals, level, input_size, fitted, as_numpy)
    846 should_fit = i_window == 0 or (refit > 0 and i_window % refit == 0)
    847 if should_fit:
--> 848     self.fit(
    849         train,
    850         id_col=id_col,
    851         time_col=time_col,
    852         target_col=target_col,
    853         static_features=static_features,
    854         dropna=dropna,
    855         keep_last_n=keep_last_n,
    856         max_horizon=max_horizon,
    857         prediction_intervals=prediction_intervals,
    858         fitted=fitted,
    859         as_numpy=as_numpy,
    860     )
    861     self.cv_models_.append(self.models_)
    862     if fitted:

File /opt/composor-env/lib/python3.10/site-packages/mlforecast/forecast.py:526, in MLForecast.fit(self, df, id_col, time_col, target_col, static_features, dropna, keep_last_n, max_horizon, prediction_intervals, fitted, as_numpy)
    524 self.fit_models(X, y)
    525 if fitted:
--> 526     fitted_values = self._compute_fitted_values(
    527         base=base,
    528         X=X,
    529         y=y,
    530         id_col=id_col,
    531         time_col=time_col,
    532         target_col=target_col,
    533         max_horizon=max_horizon,
    534     )
    535     fitted_values = ufp.drop_index_if_pandas(fitted_values)
    536     self.fcst_fitted_values_ = fitted_values

File /opt/composor-env/lib/python3.10/site-packages/mlforecast/forecast.py:411, in MLForecast._compute_fitted_values(self, base, X, y, id_col, time_col, target_col, max_horizon)
    408 for horizon in range(max_horizon):
    409     horizon_base = ufp.copy_if_pandas(base, deep=True)
    410     horizon_base = ufp.assign_columns(
--> 411         horizon_base, target_col, y[:, horizon]
    412     )
    413     horizon_fitted_values.append(horizon_base)
    414 for name, horizon_models in self.models_.items():

IndexError: index 10 is out of bounds for axis 1 with size 10"
}
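
The failing frame can be illustrated in isolation. The traceback shows a loop over range(max_horizon) that reads y[:, horizon], and the error message says axis 1 has size 10. The sketch below only mimics that indexing pattern with NumPy. The helper name first_failing_horizon and the zero-filled array are hypothetical, and the column count of 10 is taken from the error message rather than from the mlforecast source, so this does not explain why y ends up with only 10 columns.

```python
import numpy as np


def first_failing_horizon(n_target_columns, max_horizon, n_rows=5):
    """Return the first horizon index that cannot be read from y, or None.

    Hypothetical helper: it mirrors the `y[:, horizon]` access seen in the
    traceback, using a dummy target matrix of zeros.
    """
    y = np.zeros((n_rows, n_target_columns))
    for horizon in range(max_horizon):
        try:
            y[:, horizon]  # same access pattern as forecast.py line 411
        except IndexError:
            return horizon
    return None


# Assumption: y has 10 columns (from the error message) while max_horizon=14.
print(first_failing_horizon(n_target_columns=10, max_horizon=14))  # 10
print(first_failing_horizon(n_target_columns=10, max_horizon=10))  # None
```

With ten columns, indices 0 through 9 are valid, so the first failing index is 10, which matches the message "index 10 is out of bounds for axis 1 with size 10".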

Versions / Dependencies

mlforecast: 0.12.0
datasetsforecast: 0.0.8
lightgbm: 4.3.0
python: 3.10.14
OS: macOS Monterey v12.6

Reproduction script

from mlforecast import MLForecast
from datasetsforecast.m5 import M5
from mlforecast.utils import PredictionIntervals
from lightgbm import LGBMRegressor

# Get data and take subset
target, _exogenous, _static_vars = M5.load("./data_dir")
unique_ids = target["unique_id"].unique()[::100]
target = target[target["unique_id"].isin(unique_ids)]

# Define model
fcst = MLForecast(
    models=[LGBMRegressor(n_estimators=2)],
    freq="D",
    lags=[1],
)

# Get in-sample cross-validation predictions
h = 14
_cv_results = fcst.cross_validation(
    df=target,
    n_windows=4,
    h=h,
    max_horizon=h,
    fitted=True,
)

Issue Severity

High: It blocks me from completing my task.