awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Torch DeepAR implementation crashes if not supplied with static cat variables #2551

Closed pbruneau closed 1 year ago

pbruneau commented 1 year ago

Description

I have been trying to adapt the PyTorch tutorial example (https://ts.gluon.ai/stable/tutorials/advanced_topics/howto_pytorch_lightning.html) to use DeepAR instead of SimpleFeedForward. The tutorial example does not involve static or dynamic covariates, but the Torch DeepAR seems to require such covariates and crashes otherwise (when it is fine to go without covariates with the MXNet DeepAR).

Did I get something wrong, and/or is there a workaround for this issue? I am thinking of adding a dummy static cat feature, but maybe someone has something more elegant to suggest.

To Reproduce

https://gist.github.com/pbruneau/164dbe40b994185ea722aa80d27fae6c

Specifically, as num_feat* arguments are mandatory in the torch.model.deepar.DeepARModel interface, I thought the most sensible way to go would be:

    num_feat_dynamic_real=0,
    num_feat_static_real=0,
    num_feat_static_cat=0,
    cardinality=[],

Error message or code output

Traceback (most recent call last):
  File "minimal_example.py", line 68, in <module>
    trainer.fit(module, data_loader)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 604, in fit
    self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1083, in _run
    self._call_callback_hooks("on_fit_start")
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1380, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_summary.py", line 59, in on_fit_start
    model_summary = self._summary(trainer, pl_module)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_summary.py", line 73, in _summary
    return summarize(pl_module, max_depth=self._max_depth)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py", line 431, in summarize
    return ModelSummary(lightning_module, max_depth=max_depth)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py", line 189, in __init__
    self._layer_summary = self.summarize()
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py", line 246, in summarize
    self._forward_example_input()
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py", line 274, in _forward_example_input
    model(*input_)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/files/postprocessing/3rdparty/gluon-ts/src/gluonts/torch/model/deepar/lightning_module.py", line 72, in forward
    return self.model(*args, **kwargs)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/files/postprocessing/3rdparty/gluon-ts/src/gluonts/torch/model/deepar/module.py", line 341, in forward
    future_time_feat[:, :1],
  File "/files/postprocessing/3rdparty/gluon-ts/src/gluonts/torch/model/deepar/module.py", line 238, in unroll_lagged_rnn
    embedded_cat = self.embedder(feat_static_cat)
  File "/opt/conda/envs/rapids/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/files/postprocessing/3rdparty/gluon-ts/src/gluonts/torch/modules/feature.py", line 50, in forward
    dim=-1,
RuntimeError: torch.cat(): expected a non-empty list of Tensors

lostella commented 1 year ago

Hi @pbruneau, thanks for filing this! In fact, as it is currently implemented, the model does require all the inputs. So those settings should really be asserted to be > 0 (and, in the case of cardinality, to have the appropriate length).

> when it is fine to go without covariates with the MXNet DeepAR

Actually, the situation for MXNet DeepAR is exactly the same: the estimator can do without those features and simply makes them up with some dummy values when producing batches for the network to use, but the network itself does require all the inputs. Similarly, if you use the gluonts.torch.model.deepar.DeepAREstimator you can simply omit those features, but interacting directly with the DeepARModel (i.e. the underlying network) requires them.
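
As a rough illustration of what "making them up with dummy values" could look like, here is a minimal sketch in plain Python. The helper name `with_dummy_static_features` and the chosen default values are assumptions made for this illustration, not the actual GluonTS transformation code; only the field names `feat_static_cat` and `feat_static_real` follow the GluonTS dataset convention.

```python
from typing import Any, Dict


def with_dummy_static_features(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a dataset entry with placeholder static features.

    Hypothetical helper: fields that are already present are left untouched.
    """
    filled = dict(entry)
    # A single category index 0 pairs with cardinality=[1] on the network side.
    filled.setdefault("feat_static_cat", [0])
    # A single constant real-valued feature carries no information.
    filled.setdefault("feat_static_real", [0.0])
    return filled


entry = {"start": "2021-01-01 00:00", "target": [1.0, 2.0, 3.0]}
filled = with_dummy_static_features(entry)
print(filled["feat_static_cat"], filled["feat_static_real"])
```

The network then always receives inputs of the same structure, whether or not the user supplied real covariates.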

There might be an opportunity to use Optional inputs in the DeepARModel.forward method; I think we would need to explore that if we want to make the class easier to use on its own.

lostella commented 1 year ago

So to summarize, for the time being you can set the num_feat_* arguments to 1 and cardinality = [1], and provide some dummy inputs (e.g. all zeros) for the associated arguments to forward. To facilitate this, if model is your DeepARModel object, consider using

    model.input_shapes(batch_size=...)
    model.input_types()

to see exactly how a batch of data is supposed to be structured in terms of shapes and types.
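
To see why an empty `cardinality` fails while `cardinality = [1]` with all-zero inputs works, here is a minimal sketch using plain PyTorch. `TinyEmbedder` is a hypothetical stand-in written for this illustration, not the GluonTS embedder; it only mimics the pattern of concatenating one embedding per static categorical feature, which is where the traceback above ends.

```python
from typing import List

import torch
from torch import nn


class TinyEmbedder(nn.Module):
    """Hypothetical stand-in: one embedding table per static categorical feature."""

    def __init__(self, cardinalities: List[int], embedding_dim: int = 2) -> None:
        super().__init__()
        self.embedders = nn.ModuleList(
            [nn.Embedding(c, embedding_dim) for c in cardinalities]
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features has shape (batch, num_features); embed each column separately.
        embedded = [emb(features[:, i]) for i, emb in enumerate(self.embedders)]
        # torch.cat raises a RuntimeError when `embedded` is an empty list.
        return torch.cat(embedded, dim=-1)


batch_size = 4

# No categorical features: nothing to concatenate, so the call fails.
try:
    TinyEmbedder([])(torch.zeros(batch_size, 0, dtype=torch.long))
    empty_failed = False
except RuntimeError:
    empty_failed = True

# One dummy feature with a single category, fed with all zeros.
dummy_cat = torch.zeros(batch_size, 1, dtype=torch.long)
dummy_out = TinyEmbedder([1])(dummy_cat)
print(empty_failed, tuple(dummy_out.shape))
```

With a single category, index 0 is the only valid input, which is why all-zero tensors are a safe choice for the dummy feature.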

pbruneau commented 1 year ago

OK, I followed your hint and tried using gluonts.torch.model.deepar.DeepAREstimator directly. In the end I came up with this in place of lines 23-68 of the gist:

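    import torch

    from gluonts.torch.distributions import StudentTOutput
    from gluonts.torch.model.deepar import DeepAREstimator
    from gluonts.transform.sampler import ExpectedNumInstanceSampler

    # prediction_length, context_length and dataset are defined earlier in the gist.
    # Note that `model` is an estimator here, not the underlying DeepARModel.
    model = DeepAREstimator(
        prediction_length=prediction_length,
        context_length=context_length,
        distr_output=StudentTOutput(),
        freq='1H',
        # No covariates: the estimator fills in dummy features for the network.
        num_feat_dynamic_real=0,
        num_feat_static_real=0,
        num_feat_static_cat=0,
        train_sampler=ExpectedNumInstanceSampler(
            num_instances=1,
            min_future=prediction_length,
            min_past=context_length,
        ),
        batch_size=32,
        num_batches_per_epoch=50,
        trainer_kwargs={
            'max_epochs': 10,
            # Use all available GPUs, or fall back to CPU.
            'gpus': -1 if torch.cuda.is_available() else None,
        },
    )

    predictor = model.train(dataset.train, shuffle_buffer_length=10, num_workers=4)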

I guess I have no reason to use DeepARModel directly, so I'm going to close the issue :)

Maybe just one suggestion: it could be nice to update the tutorial at https://ts.gluon.ai/stable/tutorials/advanced_topics/howto_pytorch_lightning.html to include an example which uses a subclass of gluonts.torch.model.estimator.PyTorchLightningEstimator (such as gluonts.torch.model.deepar.DeepAREstimator or gluonts.torch.model.simple_feedforward.SimpleFeedForwardEstimator)!

lostella commented 1 year ago

> Maybe just one suggestion: it could be nice to update the tutorial at https://ts.gluon.ai/stable/tutorials/advanced_topics/howto_pytorch_lightning.html to include an example which uses a subclass of gluonts.torch.model.estimator.PyTorchLightningEstimator (such as gluonts.torch.model.deepar.DeepAREstimator or gluonts.torch.model.simple_feedforward.SimpleFeedForwardEstimator)!

Right, good point; another thing is to validate all the settings for DeepARModel (e.g. that they are positive).
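
Such validation could look roughly like the following sketch. The function name `check_deepar_settings` and the error messages are hypothetical; the checks follow the constraints described earlier in this thread (positive feature counts, and a `cardinality` list with one positive entry per static categorical feature).

```python
from typing import List


def check_deepar_settings(
    num_feat_dynamic_real: int,
    num_feat_static_real: int,
    num_feat_static_cat: int,
    cardinality: List[int],
) -> None:
    """Hypothetical validation; raises ValueError on unusable settings."""
    counts = {
        "num_feat_dynamic_real": num_feat_dynamic_real,
        "num_feat_static_real": num_feat_static_real,
        "num_feat_static_cat": num_feat_static_cat,
    }
    for name, value in counts.items():
        if value <= 0:
            raise ValueError(f"{name} must be > 0, got {value}")
    if len(cardinality) != num_feat_static_cat:
        raise ValueError(
            f"cardinality must have {num_feat_static_cat} entries, "
            f"got {len(cardinality)}"
        )
    if any(c <= 0 for c in cardinality):
        raise ValueError(f"cardinality entries must be > 0, got {cardinality}")


# The dummy configuration suggested above passes silently.
check_deepar_settings(1, 1, 1, [1])

# The all-zero configuration from the original report is rejected early.
try:
    check_deepar_settings(0, 0, 0, [])
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```

Failing at construction time with a message that names the offending argument would be easier to diagnose than the `torch.cat()` error raised deep inside `forward`.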