Xinyu-Wu-0000 opened this issue 1 month ago
Expected behavior
No error when max_prediction_length=1.

Actual behavior
A ValueError is raised during the validation sanity check with the TemporalFusionTransformer (TFT), NHiTS, and DeepAR models.

Code to reproduce the problem
Colab notebook: https://colab.research.google.com/drive/16ft4Prqe1pEmLHgz1CcFnTZf4kA9kFX5?usp=sharing
```python
import numpy as np
import pandas as pd
from lightning.pytorch import Trainer
from pytorch_forecasting import (
    DeepAR,
    NBeats,
    NHiTS,
    TemporalFusionTransformer,
    TimeSeriesDataSet,
)

# model to test -- the last assignment wins (DeepAR here);
# TFT, NHiTS, and DeepAR all trigger the error
model_class = NHiTS
model_class = TemporalFusionTransformer
model_class = DeepAR
# model_class = NBeats

# set the max_prediction_length parameter
max_prediction_length = 1

# synthetic data: 100 random series of 100 time steps each
n_timeseries = 100
time_points = 100
data = pd.DataFrame(
    data={
        "target": np.random.rand(time_points * n_timeseries),
        "time_varying_known_real_1": np.random.rand(time_points * n_timeseries),
        "time_idx": np.tile(np.arange(time_points), n_timeseries),
        "group_id": np.repeat(np.arange(n_timeseries), time_points),
    }
)
print(data)

training_dataset = TimeSeriesDataSet(
    data=data,
    time_idx="time_idx",
    target="target",
    group_ids=["group_id"],
    time_varying_unknown_reals=["target"],
    # NBeats does not accept covariates, so only pass the known real to the other models
    time_varying_known_reals=(
        ["time_varying_known_real_1"] if model_class != NBeats else []
    ),
    max_prediction_length=max_prediction_length,
)
validation_dataset = TimeSeriesDataSet.from_dataset(
    training_dataset, data, stop_randomization=True, predict=True
)

training_data_loader = training_dataset.to_dataloader(train=True)
validation_data_loader = validation_dataset.to_dataloader(train=False)

forecaster = model_class.from_dataset(training_dataset, log_val_interval=1)

pytorch_trainer = Trainer(
    accelerator="cpu",
    max_epochs=3,
    min_epochs=2,
    limit_train_batches=10,
)
# fails in the validation sanity check (see traceback below)
pytorch_trainer.fit(
    forecaster,
    train_dataloaders=training_data_loader,
    val_dataloaders=validation_data_loader,
)
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-20cc2e71b35b> in <cell line: 50>()
     48 )
     49 
---> 50 pytorch_trainer.fit(
     51     forecaster,
     52     train_dataloaders=training_data_loader,

18 frames
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542         self.state.status = TrainerStatus.RUNNING
    543         self.training = True
--> 544         call._call_and_handle_interrupt(
    545             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546         )

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         if trainer.strategy.launcher is not None:
     43             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44         return trainer_fn(*args, **kwargs)
     45 
     46     except _TunerExitException:

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    578             model_connected=self.lightning_module is not None,
    579         )
--> 580         self._run(model, ckpt_path=ckpt_path)
    581 
    582         assert self.state.stopped

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _run(self, model, ckpt_path)
    985         # RUN THE TRAINER
    986         # ----------------------------
--> 987         results = self._run_stage()
    988 
    989         # ----------------------------

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _run_stage(self)
   1029         if self.training:
   1030             with isolate_rng():
-> 1031                 self._run_sanity_check()
   1032             with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1033                 self.fit_loop.run()

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py in _run_sanity_check(self)
   1058 
   1059         # run eval step
-> 1060         val_loop.run()
   1061 
   1062         call._call_callback_hooks(self, "on_sanity_check_end")

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/utilities.py in _decorator(self, *args, **kwargs)
    180             context_manager = torch.no_grad
    181         with context_manager():
--> 182             return loop_run(self, *args, **kwargs)
    183 
    184     return _decorator

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/evaluation_loop.py in run(self)
    133                 self.batch_progress.is_last_batch = data_fetcher.done
    134                 # run step hooks
--> 135                 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136             except StopIteration:
    137                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/evaluation_loop.py in _evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    394             else (dataloader_iter,)
    395         )
--> 396         output = call._call_strategy_hook(trainer, hook_name, *step_args)
    397 
    398         self.batch_progress.increment_processed()

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    307 
    308     with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309         output = fn(*args, **kwargs)
    310 
    311     # restore current_fx when nested context

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py in validation_step(self, *args, **kwargs)
    410             if self.model != self.lightning_module:
    411                 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412             return self.lightning_module.validation_step(*args, **kwargs)
    413 
    414     def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py in validation_step(self, batch, batch_idx)
    629         x, y = batch
    630         log, out = self.step(x, y, batch_idx)
--> 631         log.update(self.create_log(x, y, out, batch_idx))
    632         self.validation_step_outputs.append(log)
    633         return log

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/deepar/__init__.py in create_log(self, x, y, out, batch_idx)
    337     def create_log(self, x, y, out, batch_idx):
    338         n_samples = [self.hparams.n_validation_samples, self.hparams.n_plotting_samples][self.training]
--> 339         log = super().create_log(
    340             x,
    341             y,

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py in create_log(self, x, y, out, batch_idx, prediction_kwargs, quantiles_kwargs)
    682         self.log_metrics(x, y, out, prediction_kwargs=prediction_kwargs)
    683         if self.log_interval > 0:
--> 684             self.log_prediction(
    685                 x, out, batch_idx, prediction_kwargs=prediction_kwargs, quantiles_kwargs=quantiles_kwargs
    686             )

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py in log_prediction(self, x, out, batch_idx, **kwargs)
    935             log_indices = [0]
    936         for idx in log_indices:
--> 937             fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
    938             tag = f"{self.current_stage} prediction"
    939             if self.training:

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py in plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
   2247         quantiles_kwargs.setdefault("use_metric", False)
   2248 
-> 2249         return super().plot_prediction(
   2250             x=x,
   2251             out=out,

/usr/local/lib/python3.10/dist-packages/pytorch_forecasting/models/base_model.py in plot_prediction(self, x, out, idx, add_loss_to_title, show_future_observed, ax, quantiles_kwargs, prediction_kwargs)
   1049                 else:
   1050                     quantiles = torch.tensor([[y_quantile[0, i]], [y_quantile[0, -i - 1]]])
-> 1051                 ax.errorbar(
   1052                     x_pred,
   1053                     y[[-n_pred]],

/usr/local/lib/python3.10/dist-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1440     def inner(ax, *args, data=None, **kwargs):
   1441         if data is None:
-> 1442             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1443 
   1444         bound = new_sig.bind(ax, *args, **kwargs)

/usr/local/lib/python3.10/dist-packages/matplotlib/axes/_axes.py in errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs)
   3640             if np.any(np.less(err, -err, out=res, where=(err == err))):
   3641                 # like err<0, but also works for timedelta and nan.
-> 3642                 raise ValueError(
   3643                     f"'{dep_axis}err' must not contain negative values")
   3644             # This is like

ValueError: 'yerr' must not contain negative values
```
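If I read the traceback right, the crash comes from the validation plotting path: with a single prediction step, plot_prediction falls back to ax.errorbar and builds the error band from the raw quantiles, and if the lower quantile sits below the prediction the resulting yerr has a negative entry, which matplotlib rejects. Below is a minimal, self-contained sketch (toy numbers of my own, not values or code taken from pytorch_forecasting) that reproduces the same ValueError from matplotlib alone:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

x_pred = np.array([0])        # the single predicted time step
y_pred = np.array([0.5])      # made-up point prediction
quantiles = np.array([[0.4],  # made-up lower quantile
                      [0.7]]) # made-up upper quantile

# errorbar expects yerr to hold non-negative distances from y_pred;
# subtracting y_pred from the raw quantiles makes the lower entry negative
yerr = quantiles - y_pred     # [[-0.1], [0.2]]

ax.errorbar(x_pred, y_pred, yerr=yerr, capsize=1.0)
# raises: ValueError: 'yerr' must not contain negative values
```

As a side note, this path is only reached because log_val_interval=1 is passed to from_dataset; judging from the `if self.log_interval > 0:` guard in create_log shown above, leaving that argument out should sidestep the crash, although the plotting code itself would presumably still be broken for max_prediction_length=1.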
Looking forward to a fix in an upcoming version!