jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.74k stars 599 forks source link

bug when `max_prediction_length=1` #1571

Open Xinyu-Wu-0000 opened 1 month ago

Xinyu-Wu-0000 commented 1 month ago

Expected behavior

No error is raised when `max_prediction_length=1`.

Actual behavior

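`Trainer.fit` raises an error with the TFT, NHiTS, and DeepAR models; the traceback below (from the DeepAR run) ends in `ValueError: 'yerr' must not contain negative values`.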

Code to reproduce the problem

Colab notebook: https://colab.research.google.com/drive/16ft4Prqe1pEmLHgz1CcFnTZf4kA9kFX5?usp=sharing

import numpy as np
import pandas as pd
from lightning.pytorch import Trainer
from pytorch_forecasting import (
    DeepAR,
    NBeats,
    NHiTS,
    TemporalFusionTransformer,
    TimeSeriesDataSet,
  )

# Model under test. The assignments execute in order, so the last active one
# (DeepAR) is the class actually used; reorder or comment lines to switch.
model_class = NHiTS
model_class = TemporalFusionTransformer
model_class = DeepAR
# model_class = NBeats

# Prediction horizon handed to TimeSeriesDataSet.
max_prediction_length = 1

# Synthetic data: n_timeseries groups, each with time_points rows whose
# time_idx runs 0..time_points-1; target and the known real are uniform noise.
n_timeseries = 100
time_points = 100
n_rows = time_points * n_timeseries
data = pd.DataFrame(
    {
        "target": np.random.rand(n_rows),
        "time_varying_known_real_1": np.random.rand(n_rows),
        "time_idx": np.tile(np.arange(time_points), n_timeseries),
        "group_id": np.repeat(np.arange(n_timeseries), time_points),
    }
)
print(data)

# NBeats is given no known reals; every other model gets the single known real.
if model_class != NBeats:
    known_reals = ["time_varying_known_real_1"]
else:
    known_reals = []

training_dataset = TimeSeriesDataSet(
    data=data,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    time_varying_unknown_reals=["target"],
    time_varying_known_reals=known_reals,
    max_prediction_length=max_prediction_length,
)

# The validation set reuses the training set's configuration on the same frame.
validation_dataset = TimeSeriesDataSet.from_dataset(
    training_dataset,
    data,
    stop_randomization=True,
    predict=True,
)

training_data_loader = training_dataset.to_dataloader(train=True)
validation_data_loader = validation_dataset.to_dataloader(train=False)

# NOTE(review): log_val_interval=1 presumably turns on prediction logging for
# every validation batch -- confirm against the model documentation.
forecaster = model_class.from_dataset(training_dataset, log_val_interval=1)

# Short CPU-only run: 2-3 epochs with at most 10 training batches per epoch.
trainer_options = {
    "accelerator": "cpu",
    "max_epochs": 3,
    "min_epochs": 2,
    "limit_train_batches": 10,
}
pytorch_trainer = Trainer(**trainer_options)

pytorch_trainer.fit(
    forecaster,
    train_dataloaders=training_data_loader,
    val_dataloaders=validation_data_loader,
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-4-20cc2e71b35b>](https://localhost:8080/#) in <cell line: 50>()
     48 )
     49 
---> 50 pytorch_trainer.fit(
     51     forecaster,
     52     train_dataloaders=training_data_loader,

18 frames
[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py](https://localhost:8080/#) in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542         self.state.status = TrainerStatus.RUNNING
    543         self.training = True
--> 544         call._call_and_handle_interrupt(
    545             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546         )

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py](https://localhost:8080/#) in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         if trainer.strategy.launcher is not None:
     43             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44         return trainer_fn(*args, **kwargs)
     45 
     46     except _TunerExitException:

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py](https://localhost:8080/#) in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    578             model_connected=self.lightning_module is not None,
    579         )
--> 580         self._run(model, ckpt_path=ckpt_path)
    581 
    582         assert self.state.stopped

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py](https://localhost:8080/#) in _run(self, model, ckpt_path)
    985         # RUN THE TRAINER
    986         # ----------------------------
--> 987         results = self._run_stage()
    988 
    989         # ----------------------------

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py](https://localhost:8080/#) in _run_stage(self)
   1029         if self.training:
   1030             with isolate_rng():
-> 1031                 self._run_sanity_check()
   1032             with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1033                 self.fit_loop.run()

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py](https://localhost:8080/#) in _run_sanity_check(self)
   1058 
   1059             # run eval step
-> 1060             val_loop.run()
   1061 
   1062             call._call_callback_hooks(self, "on_sanity_check_end")

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/utilities.py](https://localhost:8080/#) in _decorator(self, *args, **kwargs)
    180             context_manager = torch.no_grad
    181         with context_manager():
--> 182             return loop_run(self, *args, **kwargs)
    183 
    184     return _decorator

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/evaluation_loop.py](https://localhost:8080/#) in run(self)
    133                 self.batch_progress.is_last_batch = data_fetcher.done
    134                 # run step hooks
--> 135                 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136             except StopIteration:
    137                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/evaluation_loop.py](https://localhost:8080/#) in _evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    394             else (dataloader_iter,)
    395         )
--> 396         output = call._call_strategy_hook(trainer, hook_name, *step_args)
    397 
    398         self.batch_progress.increment_processed()

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py](https://localhost:8080/#) in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    307 
    308     with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309         output = fn(*args, **kwargs)
    310 
    311     # restore current_fx when nested context

[/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py](https://localhost:8080/#) in validation_step(self, *args, **kwargs)
    410             if self.model != self.lightning_module:
    411                 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412             return self.lightning_module.validation_step(*args, **kwargs)
    413 
    414     def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py](https://localhost:8080/#) in validation_step(self, batch, batch_idx)
    629         x, y = batch
    630         log, out = self.step(x, y, batch_idx)
--> 631         log.update(self.create_log(x, y, out, batch_idx))
    632         self.validation_step_outputs.append(log)
    633         return log

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/deepar/__init__.py](https://localhost:8080/#) in create_log(self, x, y, out, batch_idx)
    337     def create_log(self, x, y, out, batch_idx):
    338         n_samples = [self.hparams.n_validation_samples, self.hparams.n_plotting_samples][self.training]
--> 339         log = super().create_log(
    340             x,
    341             y,

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py](https://localhost:8080/#) in create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs)
    682         self.log_metrics(x, y, out, prediction_kwargs=prediction_kwargs)
    683         if self.log_interval > 0:
--> 684             self.log_prediction(
    685                 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs
    686             )

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py](https://localhost:8080/#) in log_prediction(self, x, out, batch_idx, **kwargs)
    935                 log_indices = [0]
    936             for idx in log_indices:
--> 937                 fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
    938                 tag = f"{self.current_stage} prediction"
    939                 if self.training:

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py](https://localhost:8080/#) in plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
   2247             quantiles_kwargs.setdefault("use_metric", False)
   2248 
-> 2249         return super().plot_prediction(
   2250             x=x,
   2251             out=out,

[/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py](https://localhost:8080/#) in plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
   1049                 else:
   1050                     quantiles = torch.tensor([[y_quantile[0, i]], [y_quantile[0, -i - 1]]])
-> 1051                     ax.errorbar(
   1052                         x_pred,
   1053                         y[[-n_pred]],

[/usr/local/lib/python3.10/dist-packages/matplotlib/__init__.py](https://localhost:8080/#) in inner(ax, data, *args, **kwargs)
   1440     def inner(ax, *args, data=None, **kwargs):
   1441         if data is None:
-> 1442             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1443 
   1444         bound = new_sig.bind(ax, *args, **kwargs)

[/usr/local/lib/python3.10/dist-packages/matplotlib/axes/_axes.py](https://localhost:8080/#) in errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs)
   3640             if np.any(np.less(err, -err, out=res, where=(err == err))):
   3641                 # like err<0, but also works for timedelta and nan.
-> 3642                 raise ValueError(
   3643                     f"'{dep_axis}err' must not contain negative values")
   3644             # This is like

ValueError: 'yerr' must not contain negative values
freefish1218 commented 6 days ago

Looking forward to updating the new version!