Nixtla / neuralforecast

Scalable and user friendly neural 🧠 forecasting algorithms.
https://nixtlaverse.nixtla.io/neuralforecast
Apache License 2.0

`maximum size for tensor at dimension` for NBEATS model #1193

Open · j-adamczyk opened this issue 1 week ago

j-adamczyk commented 1 week ago

What happened + What you expected to happen

I am performing a pretty standard long-term forecast:

from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
# MASE metric from sktime, which is present in the environment listed below
from sktime.performance_metrics.forecasting import mean_absolute_scaled_error

df_train = wide_to_long_df(df_train)
df_test = wide_to_long_df(df_test)

nbeats = NBEATS(
    h=len(df_test),
    input_size=window_length,
    max_steps=100,
)
nbeats = NeuralForecast(models=[nbeats], freq="D")
nbeats.fit(
    df_train,
    val_size=int(0.2 * len(df_train)),
)
y_pred_nbeats = nbeats.predict(df_test)
print(long_to_wide_df(y_pred_nbeats))
y_pred_nbeats = long_to_wide_df(y_pred_nbeats)

mase_nbeats = mean_absolute_scaled_error(df_test, y_pred_nbeats, y_train=df_train)

print(f"MASE N-BEATS: {mase_nbeats:.2f}")

df_test has length 1480. When I set h to any value over 1178, I get an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[177], line 2
      1 print("Italian pasta:")
----> 2 evaluate_mlp_models(df_pasta_train, df_pasta_test, window_length=20)
      4 print()
      6 print("Polish energy:")

Cell In[176], line 18, in evaluate_mlp_models(df_train, df_test, window_length)
     12 nbeats = NBEATS(
     13     h=len(df_test),
     14     input_size=window_length,
     15     max_steps=100,
     16 )
     17 nbeats = NeuralForecast(models=[nbeats], freq="D")
---> 18 nbeats.fit(
     19     df_train,
     20     val_size=int(0.2 * len(df_train)),
     21 )
     22 y_pred_nbeats = nbeats.predict(df_test)
     23 print(long_to_wide_df(y_pred_nbeats))

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/neuralforecast/core.py:544, in NeuralForecast.fit(self, df, static_df, val_size, sort_df, use_init_models, verbose, id_col, time_col, target_col, distributed_config)
    541     self._reset_models()
    543 for i, model in enumerate(self.models):
--> 544     self.models[i] = model.fit(
    545         self.dataset, val_size=val_size, distributed_config=distributed_config
    546     )
    548 self._fitted = True

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/neuralforecast/common/_base_windows.py:661, in BaseWindows.fit(self, dataset, val_size, test_size, random_seed, distributed_config)
    632 def fit(
    633     self,
    634     dataset,
   (...)
    638     distributed_config=None,
    639 ):
    640     """Fit.
    641 
    642     The `fit` method, optimizes the neural network's weights using the
   (...)
    659     `test_size`: int, test size for temporal cross-validation.<br>
    660     """
--> 661     return self._fit(
    662         dataset=dataset,
    663         batch_size=self.batch_size,
    664         valid_batch_size=self.valid_batch_size,
    665         val_size=val_size,
    666         test_size=test_size,
    667         random_seed=random_seed,
    668         distributed_config=distributed_config,
    669     )

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/neuralforecast/common/_base_model.py:356, in BaseModel._fit(self, dataset, batch_size, valid_batch_size, val_size, test_size, random_seed, shuffle_train, distributed_config)
    354 model = self
    355 trainer = pl.Trainer(**model.trainer_kwargs)
--> 356 trainer.fit(model, datamodule=datamodule)
    357 model.metrics = trainer.callback_metrics
    358 model.__dict__.pop("_trainer", None)

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1023, in Trainer._run_stage(self)
   1021 if self.training:
   1022     with isolate_rng():
-> 1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1025         self.fit_loop.run()

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1052, in Trainer._run_sanity_check(self)
   1049 call._call_callback_hooks(self, "on_sanity_check_start")
   1051 # run eval step
-> 1052 val_loop.run()
   1054 call._call_callback_hooks(self, "on_sanity_check_end")
   1056 # reset logger connector

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:178, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    176     context_manager = torch.no_grad
    177 with context_manager():
--> 178     return loop_run(self, *args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self)
    133     self.batch_progress.is_last_batch = data_fetcher.done
    134     # run step hooks
--> 135     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136 except StopIteration:
    137     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    138     break

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    390 hook_name = "test_step" if trainer.testing else "validation_step"
    391 step_args = (
    392     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    393     if not using_dataloader_iter
    394     else (dataloader_iter,)
    395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    398 self.batch_progress.increment_processed()
    400 if using_dataloader_iter:
    401     # update the hook kwargs now that the step method might have consumed the iterator

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:319, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:411, in Strategy.validation_step(self, *args, **kwargs)
    409 if self.model != self.lightning_module:
    410     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/neuralforecast/common/_base_windows.py:492, in BaseWindows.validation_step(self, batch, batch_idx)
    489     return np.nan
    491 # TODO: Hack to compute number of windows
--> 492 windows = self._create_windows(batch, step="val")
    493 n_windows = len(windows["temporal"])
    494 y_idx = batch["y_idx"]

File ~/.cache/pypoetry/virtualenvs/ml-time-series-forecasting-course-solution-27NYmbH6-py3.10/lib/python3.10/site-packages/neuralforecast/common/_base_windows.py:248, in BaseWindows._create_windows(self, batch, step, w_idxs)
    245     padder_right = nn.ConstantPad1d(padding=(0, self.h), value=0)
    246     temporal = padder_right(temporal)
--> 248 windows = temporal.unfold(
    249     dimension=-1, size=window_size, step=predict_step_size
    250 )
    252 # [batch, channels, windows, window_size] 0, 1, 2, 3
    253 # -> [batch * windows, window_size, channels] 0, 2, 3, 1
    254 windows_per_serie = windows.shape[2]

RuntimeError: maximum size for tensor at dimension 2 is 1188 but size is 1480

As such, I can't run NBEATS for long-term forecasting. A similar issue has been reported for TimesNet: https://github.com/Nixtla/neuralforecast/issues/1099. Recursive forecasting works, but it is quite inconvenient.
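The failing call is `torch.Tensor.unfold`, which requires the requested window size to be at most the tensor's length along the unfolded dimension. A minimal standalone reproduction of just that constraint, with the numbers taken from the error above:

import torch

# the padded validation slice is only 1188 steps long, but a window of
# 1480 steps is requested, so not even one window can be produced
t = torch.zeros(1, 1, 1188)
t.unfold(dimension=-1, size=1480, step=1)
# RuntimeError: maximum size for tensor at dimension 2 is 1188 but size is 1480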

Versions / Dependencies

Python 3.10.15, Ubuntu 24.04.

Full pip list output:

Package                   Version
------------------------- --------------
adagio                    0.2.6
aiohappyeyeballs          2.4.3
aiohttp                   3.10.10
aiosignal                 1.3.1
alembic                   1.13.3
anyio                     4.6.2.post1
appdirs                   1.4.4
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 2.4.1
async-lru                 2.0.4
async-timeout             4.0.3
attrs                     24.2.0
babel                     2.16.0
beautifulsoup4            4.12.3
bleach                    6.2.0
certifi                   2024.8.30
cffi                      1.17.1
charset-normalizer        3.4.0
click                     8.1.7
cloudpickle               3.1.0
colorlog                  6.9.0
comm                      0.2.2
contourpy                 1.3.0
coreforecast              0.0.13.1
cycler                    0.12.1
Cython                    3.0.11
debugpy                   1.8.7
decorator                 5.1.1
defusedxml                0.7.1
exceptiongroup            1.2.2
executing                 2.1.0
fastjsonschema            2.20.0
filelock                  3.16.1
fonttools                 4.54.1
fqdn                      1.5.1
frozenlist                1.5.0
fs                        2.4.16
fsspec                    2024.10.0
fugue                     0.9.1
greenlet                  3.1.1
h11                       0.14.0
holidays                  0.59
httpcore                  1.0.6
httpx                     0.27.2
idna                      3.10
ipykernel                 6.29.5
ipython                   8.29.0
ipywidgets                8.1.5
isoduration               20.11.0
jedi                      0.19.1
Jinja2                    3.1.4
joblib                    1.4.2
json5                     0.9.25
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter                   1.1.1
jupyter_client            8.6.3
jupyter-console           6.6.3
jupyter_core              5.7.2
jupyter-events            0.10.0
jupyter-lsp               2.2.5
jupyter_server            2.14.2
jupyter_server_terminals  0.5.3
jupyterlab                4.3.0
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
jupyterlab_widgets        3.0.13
kiwisolver                1.4.7
lightning-utilities       0.11.8
llvmlite                  0.43.0
Mako                      1.3.6
MarkupSafe                3.0.2
matplotlib                3.9.2
matplotlib-inline         0.1.7
mistune                   3.0.2
mpmath                    1.3.0
msgpack                   1.1.0
multidict                 6.1.0
nbclient                  0.10.0
nbconvert                 7.16.4
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  3.4.2
neuralforecast            1.7.5
notebook                  7.0.7
notebook_shim             0.2.4
numba                     0.60.0
numpy                     1.26.4
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
optuna                    4.0.0
overrides                 7.7.0
packaging                 24.1
pandas                    2.2.3
pandocfilters             1.5.1
parso                     0.8.4
patsy                     0.5.6
pexpect                   4.9.0
pillow                    11.0.0
pip                       24.1
platformdirs              4.3.6
pmdarima                  2.0.4
prometheus_client         0.21.0
prompt_toolkit            3.0.48
propcache                 0.2.0
protobuf                  5.28.3
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
pyarrow                   18.0.0
pycparser                 2.22
Pygments                  2.18.0
pyparsing                 3.2.0
python-dateutil           2.9.0.post0
python-json-logger        2.0.7
pytorch-lightning         2.4.0
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
ray                       2.38.0
referencing               0.35.1
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rpds-py                   0.20.1
scikit-base               0.11.0
scikit-learn              1.5.2
scipy                     1.14.1
seaborn                   0.13.2
Send2Trash                1.8.3
setuptools                75.3.0
six                       1.16.0
sktime                    0.34.0
sniffio                   1.3.1
soupsieve                 2.6
SQLAlchemy                2.0.36
stack-data                0.6.3
statsforecast             1.7.8
statsmodels               0.14.4
sympy                     1.13.1
tensorboardX              2.6.2.2
terminado                 0.18.1
threadpoolctl             3.5.0
tinycss2                  1.4.0
tomli                     2.0.2
torch                     2.5.1
torchmetrics              1.5.1
tornado                   6.4.1
tqdm                      4.66.6
traitlets                 5.14.3
triad                     0.9.8
triton                    3.1.0
types-python-dateutil     2.9.0.20241003
typing_extensions         4.12.2
tzdata                    2024.2
uri-template              1.3.0
urllib3                   2.2.3
utilsforecast             0.2.7
wcwidth                   0.2.13
webcolors                 24.8.0
webencodings              0.5.1
websocket-client          1.8.0
widgetsnbextension        4.0.13
yarl                      1.17.1

Reproduction script

Just run the NBEATS example from the docs with long data: https://nixtlaverse.nixtla.io/neuralforecast/models.nbeats.html. I used this dataset: https://archive.ics.uci.edu/dataset/611/hierarchical+sales+data.

Code for loading:

import pandas as pd

df_pasta = pd.read_csv("italian_pasta.csv")
for num in [1, 2, 3, 4]:
    # sum all item-level columns for brand `num`; the pattern is anchored because
    # an unanchored "QTY_B{num}*" would also match the other brands' columns
    df_pasta[f"value_B{num}"] = df_pasta.filter(regex=rf"^QTY_B{num}").sum(axis="columns")

df_pasta = df_pasta.set_index(pd.to_datetime(df_pasta["DATE"])).asfreq("d")
df_pasta = df_pasta[["value_B1", "value_B2", "value_B3", "value_B4"]]

Helper functions:

def wide_to_long_df(df: pd.DataFrame) -> pd.DataFrame:
    df = pd.melt(df, ignore_index=False).reset_index(names="date")
    df = df.rename(columns={"variable": "unique_id", "date": "ds", "value": "y"})
    return df

def long_to_wide_df(df: pd.DataFrame) -> pd.DataFrame:
    df = df.reset_index(names="unique_id")
    values_col = df.columns[-1]
    df = pd.pivot(df, columns="unique_id", index="ds", values=values_col)
    return df

Issue Severity

High: It blocks me from completing my task.

marcopeix commented 4 days ago

Hello! The error occurs when you fit the model. Here, the combination of training set length, validation size, and horizon causes training to fail. Basically, your horizon and validation size are likely too large for the amount of data in the training set.

You can try removing the validation set completely or reducing the horizon; either should fix your problem.
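For illustration, a minimal sketch of that workaround (the `h=128` value is just an example; `window_length` and `df_train` are from the original report):

nbeats = NBEATS(
    h=128,                    # reduced horizon, example value
    input_size=window_length,
    max_steps=100,
)
nf = NeuralForecast(models=[nbeats], freq="D")
nf.fit(df_train)              # no val_size: no validation windows to unfold
y_pred = nf.predict()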

j-adamczyk commented 4 days ago

Reducing the validation set is no problem. However, I still want to do long-horizon predictions. Admittedly, "long" here is not particularly long compared to the training data, and I can accept shorter horizons for direct forecasting. But then I have to implement autoregressive forecasting manually, which is definitely not convenient.

For example, if the model allows forecasting with a horizon of 128, that's fine; it's quite long. But when I need a forecast of 500 steps, I have to do it manually in a loop: forecast, append the predictions to the input data, forecast again, repeat, and finally cut the outputs to the required length. This should be handled by the library underneath, or at least wrapped in a function.
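For illustration, a rough sketch of that manual loop (`recursive_forecast` is a hypothetical helper, not part of neuralforecast; the model column name, and whether `predict` returns `unique_id` as a column or as the index, vary by version):

import pandas as pd

def recursive_forecast(nf, df_train, total_horizon, h, model_col="NBEATS"):
    """Hypothetical helper: forecast total_horizon steps by repeatedly
    predicting h steps and feeding the predictions back in as history."""
    history = df_train.copy()
    chunks = []
    for _ in range(-(-total_horizon // h)):  # ceil(total_horizon / h) iterations
        fcst = nf.predict(df=history)
        if "unique_id" not in fcst.columns:  # some versions use it as the index
            fcst = fcst.reset_index()
        chunks.append(fcst)
        # treat this chunk of predictions as observations for the next pass
        new_rows = fcst.rename(columns={model_col: "y"})[["unique_id", "ds", "y"]]
        history = pd.concat([history, new_rows], ignore_index=True)
    out = pd.concat(chunks, ignore_index=True).sort_values(["unique_id", "ds"])
    # cut each series back to the required length
    return out.groupby("unique_id").head(total_horizon)

# e.g. y_pred = recursive_forecast(nf, df_train, total_horizon=500, h=128)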