jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Can TFT fit a univariate signal without any static/dynamic covariates? #458

Open harashm opened 3 years ago

harashm commented 3 years ago

I tried to fit a TFT model to a univariate signal, with no covariates at all, and got the following error. I wonder whether this model can be fitted to my scenario at all, because after digging into the code I saw some if/else statements based on 'num_inputs > 1', so I suspect it's a bug...

Code to reproduce the problem

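from pytorch_forecasting import TimeSeriesDataSet

# df has the columns "time_idx", "value" and "series"; the target is the only input
t = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target='value',
    group_ids=['series'],
    max_encoder_length=100,
    max_prediction_length=10,
    time_varying_unknown_reals=['value'],
)
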
Traceback (most recent call last):
  File "C:/MyRepo/algorithms/auto_prediction/pytorch_forcasting_examples/nbeats_ar.py", line 265, in <module>
    main()
  File "C:/MyRepo/algorithms/auto_prediction/pytorch_forcasting_examples/nbeats_ar.py", line 230, in main
    val_dataloaders=val_dataloader,
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 499, in fit
    self.dispatch()
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 546, in dispatch
    self.accelerator.start_training(self)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 73, in start_training
    self.training_type_plugin.start_training(trainer)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 114, in start_training
    self._results = trainer.run_train()
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 607, in run_train
    self.run_sanity_check(self.lightning_module)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 860, in run_sanity_check
    _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 725, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 166, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 177, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 131, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_forecasting\models\base_model.py", line 348, in validation_step
    log, _ = self.step(x, y, batch_idx)  # log loss
  File "C:/MyRepo\venv\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py", line 523, in step
    log, out = super().step(x, y, batch_idx)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_forecasting\models\base_model.py", line 438, in step
    out = self(x, **kwargs)
  File "C:/MyRepo\venv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py", line 440, in forward
    static_context_variable_selection[:, max_encoder_length:],
  File "C:/MyRepo\venv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:/MyRepo\venv\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\sub_modules.py", line 341, in forward
    name = next(iter(self.single_variable_grns.keys()))
StopIteration
jdb78 commented 3 years ago

You need at least time_idx as a known covariate, i.e. a variable that is known in the future.
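To see why a known covariate is required, here is a simplified sketch of how time-varying variables are split between the encoder and the decoder. It assumes, as in the TFT design, that the decoder only sees variables known in the future. `split_tft_inputs` is a hypothetical helper, not a pytorch-forecasting function, and it ignores any derived features that `TimeSeriesDataSet` may add on its own.

```python
def split_tft_inputs(unknown_reals, known_reals=()):
    """Hypothetical helper: which time-varying variables feed the encoder vs. the decoder."""
    # The encoder covers the past, where both known and unknown values are observed.
    encoder_inputs = list(known_reals) + list(unknown_reals)
    # The decoder covers the future, where only known covariates are available.
    decoder_inputs = list(known_reals)
    return encoder_inputs, decoder_inputs


# Univariate setup from the issue: only the target itself.
print(split_tft_inputs(["value"]))  # (['value'], [])
# With time_idx as a known covariate, the decoder gets one input.
print(split_tft_inputs(["value"], ["time_idx"]))  # (['time_idx', 'value'], ['time_idx'])
```

With only the target, the decoder's variable list is empty, which is the zero-input situation described further down in this thread.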

kevinfindlay commented 2 years ago

I had this bug as well. It results in an AttributeError:

"AttributeError (note: full exception trace is shown but execution is paused at: _run_module_as_main) 'NoneType' object has no attribute 'item'" in on_train_batch_end (line 331). This happens because the output is None and the code then calls .item() on it.

The output is None because a StopIteration is raised during the forward pass, so the step never produces a result.

This in turn causes the TemporalFusionTransformer to bomb out at line 438 of its forward method, in the call below, so it never gets to the LSTM:

    embeddings_varying_decoder, decoder_sparse_weights = self.decoder_variable_selection(
        embeddings_varying_decoder,
        static_context_variable_selection[:, max_encoder_length:],
    )

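The root cause appears to be the forward function of VariableSelectionNetwork in sub_modules. It checks whether self.num_inputs is greater than 1. In this situation, however, self.num_inputs = 0, because the decoder variable selection only receives variables that are known in the future and there are none. The single-input branch then calls next(iter(self.single_variable_grns.keys())) on an empty dictionary, which raises the StopIteration seen at the end of the traceback above.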
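The failure mode can be reproduced without PyTorch. The sketch below is a hypothetical, stripped-down stand-in for the num_inputs branch described above: with more than one input the real module computes weights over all variables, and otherwise it takes the single variable's name with next(iter(...)). With zero inputs the dictionary is empty, so next raises StopIteration.

```python
def select_variables(single_variable_grns):
    """Hypothetical simplification of the num_inputs branch in VariableSelectionNetwork.forward."""
    num_inputs = len(single_variable_grns)
    if num_inputs > 1:
        # The real module would compute softmax weights over all variables here.
        return sorted(single_variable_grns)
    # Single-input path: silently assumes exactly one variable exists.
    name = next(iter(single_variable_grns.keys()))
    return [name]


print(select_variables({"time_idx": None}))  # ['time_idx']
try:
    select_variables({})  # decoder with no known covariates
except StopIteration:
    print("StopIteration: no variables to select")
```

The `else` branch treats "not more than one" as "exactly one", so the zero case falls through to an empty iterator.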

This was resolved when I added a "time_varying_known_reals" entry to complement the single "time_varying_unknown_reals" entry I had. I therefore had 2 input variables rather than 1.

The documentation does not say that you need more than one input variable:

https://pytorch-forecasting.readthedocs.io/en/v0.2.2/api/pytorch_forecasting.data.TimeSeriesDataSet.html

The fix would be to either:

  1. Add an explicit assert to prevent this from happening, or
  2. Handle the self.num_inputs = 0 case, or find out how the model gets into this state.
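Option 1 could look roughly like the following. `validate_decoder_inputs` is a hypothetical helper, not part of pytorch-forecasting; it fails early with a readable message instead of a StopIteration deep inside the forward pass.

```python
def validate_decoder_inputs(time_varying_known_reals=(), time_varying_known_categoricals=()):
    """Hypothetical up-front check: the TFT decoder needs at least one known future variable."""
    n_known = len(time_varying_known_reals) + len(time_varying_known_categoricals)
    if n_known == 0:
        raise ValueError(
            "TemporalFusionTransformer needs at least one time-varying known variable "
            "for the decoder, e.g. time_varying_known_reals=['time_idx']."
        )
    return n_known


print(validate_decoder_inputs(time_varying_known_reals=["time_idx"]))  # 1
try:
    validate_decoder_inputs()
except ValueError as err:
    print(err)
```

It raises ValueError rather than using a bare assert, because assert statements are skipped when Python runs with -O.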

FYI, I wasted 2 days debugging into the code, and this will likely catch out other newbies like me.

kevinfindlay commented 2 years ago

This is the solution I found:

training = TimeSeriesDataSet(
    data,
    time_idx='dateIndex',
    target='variable1',
    group_ids=['groupID'],
    max_encoder_length=10,
    min_encoder_length=10,
    min_prediction_idx=1,
    max_prediction_length=5,
    min_prediction_length=5,
    time_varying_unknown_reals=['variable1'],
    time_varying_known_reals=['groupID'],  # commenting this in solves the problem
)