unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[BUG] Spearman rank metric raises error #1378

Closed Kayne88 closed 1 year ago

Kayne88 commented 1 year ago

Describe the bug
Evaluation of the Spearman rank correlation coefficient (torchmetrics SpearmanCorrCoef) fails due to a shape mismatch: https://torchmetrics.readthedocs.io/en/latest/regression/spearman_corr_coef.html

Also, I am not sure which keyword to provide to EarlyStopping.monitor when I want to monitor the torch metric provided to RNNModel.

To Reproduce

import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.models import RNNModel
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchmetrics import SpearmanCorrCoef

target_ts = TimeSeries.from_series(pd.Series(np.random.random(1000)), freq="1H")
covariates = TimeSeries.from_dataframe(pd.DataFrame(np.random.random((1000, 10))), freq="1H")

early_stopper = EarlyStopping(
    monitor="val_metric",
    patience=10,
    min_delta=0.005,
    mode='max',
)

pl_trainer_kwargs={"callbacks": [early_stopper], "accelerator": "gpu"}

model = RNNModel(
    input_chunk_length=5,
    model="LSTM",
    hidden_dim=25,
    n_rnn_layers=2,
    dropout=0.1,
    training_length=24,
    torch_metrics=SpearmanCorrCoef(),
    pl_trainer_kwargs=pl_trainer_kwargs
)

model.historical_forecasts(
    target_ts,
    past_covariates=covariates,
    forecast_horizon=24*7,
    retrain=24*7*4,
    start=0.8,
    verbose=True
)

Expected behavior
The evaluation of the torch metric should work.

Additional context
Full stack trace:

0%|          | 0/34 [00:00<?, ?it/s]INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | MSELoss          | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | rnn           | LSTM             | 8.0 K 
4 | V             | Linear           | 26    
---------------------------------------------------
8.0 K     Trainable params
0         Non-trainable params
8.0 K     Total params
0.064     Total estimated model params size (MB)
Epoch 0: 0%
0/25 [00:00<?, ?it/s]
  0%|          | 0/34 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-66-2fe86ef4aac5>](https://localhost:8080/#) in <module>
      5     retrain=24*7*4,
      6     start=0.8,
----> 7     verbose=True
      8 )

49 frames
[/usr/local/lib/python3.7/dist-packages/darts/utils/utils.py](https://localhost:8080/#) in sanitized_method(self, *args, **kwargs)
    170 
    171                 getattr(self, sanity_check_method)(*only_args.values(), **only_kwargs)
--> 172             return method_to_sanitize(self, *only_args.values(), **only_kwargs)
    173 
    174         return sanitized_method

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/forecasting_model.py](https://localhost:8080/#) in historical_forecasts(self, series, past_covariates, future_covariates, num_samples, train_length, start, forecast_horizon, stride, retrain, overlap_end, last_points_only, verbose)
    495                     series=train,
    496                     past_covariates=past_covariates,
--> 497                     future_covariates=future_covariates,
    498                 )
    499 

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/forecasting_model.py](https://localhost:8080/#) in _fit_wrapper(self, series, past_covariates, future_covariates)
   1245             future_covariates=future_covariates
   1246             if self.uses_future_covariates
-> 1247             else None,
   1248         )
   1249 

[/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py](https://localhost:8080/#) in decorator(self, *args, **kwargs)
    110         with fork_rng():
    111             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
--> 112             return decorated(self, *args, **kwargs)
    113 
    114     return decorator

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py](https://localhost:8080/#) in fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, trainer, verbose, epochs, max_samples_per_ts, num_loader_workers)
    738 
    739         return self.fit_from_dataset(
--> 740             train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers
    741         )
    742 

[/usr/local/lib/python3.7/dist-packages/darts/utils/torch.py](https://localhost:8080/#) in decorator(self, *args, **kwargs)
    110         with fork_rng():
    111             manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
--> 112             return decorated(self, *args, **kwargs)
    113 
    114     return decorator

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py](https://localhost:8080/#) in fit_from_dataset(self, train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers)
    892 
    893         # Train model
--> 894         self._train(train_loader, val_loader)
    895         return self
    896 

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/torch_forecasting_model.py](https://localhost:8080/#) in _train(self, train_loader, val_loader)
    918             train_dataloaders=train_loader,
    919             val_dataloaders=val_loader,
--> 920             ckpt_path=ckpt_path,
    921         )
    922 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    581         self.strategy._lightning_module = model
    582         call._call_and_handle_interrupt(
--> 583             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    584         )
    585 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/call.py](https://localhost:8080/#) in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    622             model_connected=self.lightning_module is not None,
    623         )
--> 624         self._run(model, ckpt_path=self.ckpt_path)
    625 
    626         assert self.state.stopped

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run(self, model, ckpt_path)
   1059         self._checkpoint_connector.resume_end()
   1060 
-> 1061         results = self._run_stage()
   1062 
   1063         log.detail(f"{self.__class__.__name__}: trainer tearing down")

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run_stage(self)
   1138         if self.predicting:
   1139             return self._run_predict()
-> 1140         self._run_train()
   1141 
   1142     def _pre_training_routine(self) -> None:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _run_train(self)
   1161 
   1162         with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1163             self.fit_loop.run()
   1164 
   1165     def _run_evaluate(self) -> _EVALUATE_OUTPUT:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py](https://localhost:8080/#) in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/fit_loop.py](https://localhost:8080/#) in advance(self)
    265         self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device)
    266         with self.trainer.profiler.profile("run_training_epoch"):
--> 267             self._outputs = self.epoch_loop.run(self._data_fetcher)
    268 
    269     def on_advance_end(self) -> None:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py](https://localhost:8080/#) in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py](https://localhost:8080/#) in advance(self, data_fetcher)
    212 
    213             with self.trainer.profiler.profile("run_training_batch"):
--> 214                 batch_output = self.batch_loop.run(kwargs)
    215 
    216         self.batch_progress.increment_processed()

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py](https://localhost:8080/#) in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py](https://localhost:8080/#) in advance(self, kwargs)
     86                 self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
     87             )
---> 88             outputs = self.optimizer_loop.run(optimizers, kwargs)
     89         else:
     90             outputs = self.manual_loop.run(kwargs)

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/loop.py](https://localhost:8080/#) in run(self, *args, **kwargs)
    197             try:
    198                 self.on_advance_start(*args, **kwargs)
--> 199                 self.advance(*args, **kwargs)
    200                 self.on_advance_end()
    201                 self._restarting = False

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in advance(self, optimizers, kwargs)
    198         kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
    199 
--> 200         result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
    201         if result.loss is not None:
    202             # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in _run_optimization(self, kwargs, optimizer)
    245         else:
    246             # the `batch_idx` is optional with inter-batch parallelism
--> 247             self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
    248 
    249         result = closure.consume_result()

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in _optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    364             on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
    365             using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
--> 366             using_lbfgs=is_lbfgs,
    367         )
    368 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1303 
   1304         with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1305             output = fn(*args, **kwargs)
   1306 
   1307         # restore current_fx when nested context

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/module.py](https://localhost:8080/#) in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1659 
   1660         """
-> 1661         optimizer.step(closure=optimizer_closure)
   1662 
   1663     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/core/optimizer.py](https://localhost:8080/#) in step(self, closure, **kwargs)
    167 
    168         assert self._strategy is not None
--> 169         step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
    170 
    171         self._on_after_step()

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py](https://localhost:8080/#) in optimizer_step(self, optimizer, opt_idx, closure, model, **kwargs)
    233         assert isinstance(model, pl.LightningModule)
    234         return self.precision_plugin.optimizer_step(
--> 235             optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
    236         )
    237 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py](https://localhost:8080/#) in optimizer_step(self, optimizer, model, optimizer_idx, closure, **kwargs)
    119         """Hook to run the optimizer step."""
    120         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
--> 121         return optimizer.step(closure=closure, **kwargs)
    122 
    123     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:

[/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    111                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
    112                 with torch.autograd.profiler.record_function(profile_name):
--> 113                     return func(*args, **kwargs)
    114             return wrapper
    115 

[/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py](https://localhost:8080/#) in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

[/usr/local/lib/python3.7/dist-packages/torch/optim/adam.py](https://localhost:8080/#) in step(self, closure)
    116         if closure is not None:
    117             with torch.enable_grad():
--> 118                 loss = closure()
    119 
    120         for group in self.param_groups:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py](https://localhost:8080/#) in _wrap_closure(self, model, optimizer, optimizer_idx, closure)
    105         consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    106         """
--> 107         closure_result = closure()
    108         self._after_closure(model, optimizer, optimizer_idx)
    109         return closure_result

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
    145 
    146     def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 147         self._result = self.closure(*args, **kwargs)
    148         return self._result.loss
    149 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in closure(self, *args, **kwargs)
    131 
    132     def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 133         step_output = self._step_fn()
    134 
    135         if step_output.closure_loss is None:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py](https://localhost:8080/#) in _training_step(self, kwargs)
    404         """
    405         # manually capture logged metrics
--> 406         training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
    407         self.trainer.strategy.post_training_step()
    408 

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py](https://localhost:8080/#) in _call_strategy_hook(self, hook_name, *args, **kwargs)
   1441 
   1442         with self.profiler.profile(f"[Strategy]{self.strategy.__class__.__name__}.{hook_name}"):
-> 1443             output = fn(*args, **kwargs)
   1444 
   1445         # restore current_fx when nested context

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py](https://localhost:8080/#) in training_step(self, *args, **kwargs)
    376         with self.precision_plugin.train_step_context():
    377             assert isinstance(self.model, TrainingStep)
--> 378             return self.model.training_step(*args, **kwargs)
    379 
    380     def post_training_step(self) -> None:

[/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/precision/double.py](https://localhost:8080/#) in training_step(self, *args, **kwargs)
     42         return self.module.training_step(
     43             *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
---> 44             **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
     45         )
     46 

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/pl_forecasting_module.py](https://localhost:8080/#) in training_step(self, train_batch, batch_idx)
    152         loss = self._compute_loss(output, target)
    153         self.log("train_loss", loss, batch_size=train_batch[0].shape[0], prog_bar=True)
--> 154         self._calculate_metrics(output, target, self.train_metrics)
    155         return loss
    156 

[/usr/local/lib/python3.7/dist-packages/darts/models/forecasting/pl_forecasting_module.py](https://localhost:8080/#) in _calculate_metrics(self, output, target, metrics)
    267             # If there's no likelihood, nr_params=1, and we need to squeeze out the
    268             # last dimension of model output, for properly computing the metric.
--> 269             _metric = metrics(target, output.squeeze(dim=-1))
    270 
    271         self.log_dict(

[/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.7/dist-packages/torchmetrics/collections.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
    166         will be filtered based on the signature of the individual metric.
    167         """
--> 168         res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
    169         res = _flatten_dict(res)
    170         return {self._set_name(k): v for k, v in res.items()}

[/usr/local/lib/python3.7/dist-packages/torchmetrics/collections.py](https://localhost:8080/#) in <dictcomp>(.0)
    166         will be filtered based on the signature of the individual metric.
    167         """
--> 168         res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
    169         res = _flatten_dict(res)
    170         return {self._set_name(k): v for k, v in res.items()}

[/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.7/dist-packages/torchmetrics/metric.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
    243             self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    244         else:
--> 245             self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    246 
    247         return self._forward_cache

[/usr/local/lib/python3.7/dist-packages/torchmetrics/metric.py](https://localhost:8080/#) in _forward_reduce_state_update(self, *args, **kwargs)
    307 
    308         # calculate batch state and compute batch value
--> 309         self.update(*args, **kwargs)
    310         batch_val = self.compute()
    311 

[/usr/local/lib/python3.7/dist-packages/torchmetrics/metric.py](https://localhost:8080/#) in wrapped_func(*args, **kwargs)
    393             with torch.set_grad_enabled(self._enable_grad):
    394                 try:
--> 395                     update(*args, **kwargs)
    396                 except RuntimeError as err:
    397                     if "Expected all tensors to be on" in str(err):

[/usr/local/lib/python3.7/dist-packages/torchmetrics/regression/spearman.py](https://localhost:8080/#) in update(self, preds, target)
     88             target: Ground truth values
     89         """
---> 90         preds, target = _spearman_corrcoef_update(preds, target, num_outputs=self.num_outputs)
     91         self.preds.append(preds)
     92         self.target.append(target)

[/usr/local/lib/python3.7/dist-packages/torchmetrics/functional/regression/spearman.py](https://localhost:8080/#) in _spearman_corrcoef_update(preds, target, num_outputs)
     70     if preds.ndim > 2 or target.ndim > 2:
     71         raise ValueError(
---> 72             f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
     73             f" but got {target.ndim} and {preds.ndim}."
     74         )

ValueError: Expected both predictions and target to be either 1- or 2-dimensional tensors, but got 3 and 3.
madtoinou commented 1 year ago

Hi!

After tinkering with the code, I think I found a partial solution to your issue. As the exception suggests, SpearmanCorrCoef expects its inputs to have either one or two dimensions. Due to the internals of Darts, the tensors passed to this torchmetric both have shape [32, 24, 1]. At the moment it is not possible to control these dimensions through Darts, but you could define a wrapper that takes care of the dimension reduction:

class WrapperSpearman(SpearmanCorrCoef):
    def update(self, preds, target):
        # drop the singleton dimension: [batch, 24, 1] -> [batch, 24]
        super().update(preds.squeeze(), target.squeeze())

After solving this dimension issue, a new exception is raised saying that the second dimension of the target must match the num_outputs parameter of SpearmanCorrCoef. This can easily be solved by setting num_outputs=24 when instantiating your metric.

Finally, the logger cannot accept a tensor of length 24 and suggests calling mean() on the output. This is possible by overriding the compute method as well:

    # also added to the WrapperSpearman class above
    def compute(self):
        output = super().compute()
        return output.mean()
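
Putting the three adjustments together, a minimal sketch of the full wrapper and how it could be passed to the model (this uses squeeze(-1) so that only the trailing dimension is dropped, and num_outputs=24 to mirror the training_length=24 above):

from torchmetrics import SpearmanCorrCoef

class WrapperSpearman(SpearmanCorrCoef):
    def update(self, preds, target):
        # Darts passes tensors of shape [batch, training_length, 1];
        # drop the trailing dimension so the metric sees 2-D tensors.
        super().update(preds.squeeze(-1), target.squeeze(-1))

    def compute(self):
        # one correlation per output step -> average to a scalar for logging
        return super().compute().mean()

model = RNNModel(
    input_chunk_length=5,
    model="LSTM",
    hidden_dim=25,
    n_rnn_layers=2,
    dropout=0.1,
    training_length=24,
    torch_metrics=WrapperSpearman(num_outputs=24),  # num_outputs matches training_length
    pl_trainer_kwargs=pl_trainer_kwargs,
)

With this, the logged value is the Spearman correlation averaged over the 24 forecast steps.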

As for the EarlyStopping from PyTorch-Ignite, I think you can get the PyTorch Lightning trainer of the model by accessing the corresponding attribute model.trainer (which will be instantiated at the first call of fit or predict) and pass it to the EarlyStopping object (as described in their example). I am not sure how these modules will interact; if it does not work, I don't think you can use the PyTorch-Ignite trainer with Darts models, but I might be wrong...

Kayne88 commented 1 year ago

That indeed works! Thanks for your investigation. Where is that 24 parameter coming from? Is it the training length, i.e. the length of the targets in each batch?

I was also wondering: when using historical_forecasts, how is the "rolling" training split into train and val? I would ideally want to monitor the val metric of WrapperSpearman whenever the model is retrained.

PS: I am using early stopping from pytorch lightning - from pytorch_lightning.callbacks.early_stopping import EarlyStopping
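
For reference, a hedged sketch of pointing that EarlyStopping at the wrapped metric. The monitor key is an assumption: darts logs validation torch metrics with a "val_" prefix followed by the metric's name, which would give "val_WrapperSpearman" here, but the exact key should be verified (e.g. by inspecting model.trainer.callback_metrics after a first fit). Note also that validation metrics require a validation series to be passed to fit, which historical_forecasts does not do (as discussed below).

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Assumed logged name: "val_" + metric name; verify via trainer.callback_metrics.
early_stopper = EarlyStopping(
    monitor="val_WrapperSpearman",
    patience=10,
    min_delta=0.005,
    mode="max",  # a higher Spearman correlation is better
)

pl_trainer_kwargs = {"callbacks": [early_stopper], "accelerator": "gpu"}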

madtoinou commented 1 year ago

It's indeed the value of the training_length argument of the RNNModel.

The training split in historical_forecasts is actually not "rolling" but "expanding"; to cite the explanation in the method docstring: "it repeatedly builds a training set from the beginning of series. It trains the model on the training set, emits a forecast of length equal to forecast_horizon, and then moves the end of the training set forward by stride time steps".

TL;DR: the training set grows incrementally as the historical forecasts progress toward the last date.

For some models, this behavior can be disabled with the retrain argument. If it is set to False, the model is trained only on the first fragment of the input time series (up to the start argument value) and the "expansion" does not occur (see the sketch below).
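
As an illustration, a minimal sketch of such a call with retraining disabled, reusing the objects from the reproduction code above:

# Fit once on the data before `start`, then emit all historical forecasts
# without re-fitting; the expanding retraining described above is skipped.
model.historical_forecasts(
    target_ts,
    past_covariates=covariates,
    forecast_horizon=24 * 7,
    start=0.8,
    retrain=False,
    verbose=True,
)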

I am not sure where you could find the metric for each individual retraining: from my understanding of the source code, each retraining overwrites the metric obtained during the previous historical forecast, since they share the same name/logger.

If you're interested in such a feature, we can add it to the backlog, and we would be very happy to have you contribute to the project by implementing it!

Kayne88 commented 1 year ago

Yes, you are indeed right, "expanding" is the default.

My question regarding the retraining was more whether there is a train-val split for each retraining, because with early stopping I would like to monitor the val metric for the historical_forecasts method as well.

Regarding contribution: I am pretty new to PyTorch and PyTorch Lightning; I would honestly not consider myself proficient enough.

madtoinou commented 1 year ago

Oh sorry, I misunderstood your question.

There is apparently no training-validation split during the retraining in the historical_forecasts method, but what you're looking for is the metric between the prediction of each "historical window" and the ground truth, right? From my small experiments, the validation loss/metric does not appear to be logged.

After verification, the reported values are indeed overwritten during model retraining and only the last one is available.

I'll add it to the backlog and start thinking about how to implement this feature.