zalandoresearch / pytorch-ts

PyTorch based Probabilistic Time Series forecasting framework based on GluonTS backend
MIT License
1.21k stars, 191 forks

Using TransMAF with other datasets #99

Open g-benton opened 2 years ago

g-benton commented 2 years ago

Hi All,

I'm having trouble figuring out how to run TransMAF with any dataset other than solar (shown in the example notebook). I've seen the earlier advice on setting the input_size parameter, which worked in other instances, but with TransMAF I get a less informative error message.

If I run the code attached below, all I get for an error is RuntimeError: mat1 dim 1 must match mat2 dim 0, which doesn't tell me what the input dimension ought to be for this dataset.

Does anybody know how to set up these models so that they can be used with other datasets?

Thanks!

import numpy as np
import pandas as pd

import torch

from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset
from pts.model.tempflow import TempFlowEstimator
from pts.model.transformer_tempflow import TransformerTempFlowEstimator
from pts import Trainer
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Prepare data set

dataset = get_dataset("traffic_nips", regenerate=False)

dataset.metadata

train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))

test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), 
                                   max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality))

dataset_train = train_grouper(dataset.train)
dataset_test = test_grouper(dataset.test)

## Evaluator

evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:],
                                  target_agg_funcs={'sum': np.sum})

## `Transformer-MAF`

estimator = TransformerTempFlowEstimator(
    d_model=16,
    num_heads=4,
    input_size=552,
    target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length*4,
    flow_type='MAF',
    dequantize=True,
    freq=dataset.metadata.freq,
    trainer=Trainer(
        device=device,
        epochs=14,
#         epochs=1,
        learning_rate=1e-3,
        num_batches_per_epoch=100,
        batch_size=64,
    )
)

predictor = estimator.train(dataset_train)
kashif commented 2 years ago

Right, I believe that error message should include a size, and that size is the input_size you need to pass... can you paste the full error?

g-benton commented 2 years ago

Thank you for the fast reply! Here is the full error message:

RuntimeError                              Traceback (most recent call last)
<ipython-input-1-aaf679a2dfb8> in <module>
     55 )
     56 
---> 57 predictor = estimator.train(dataset_train)

~/pytorch-ts/pts/model/estimator.py in train(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)
    177         **kwargs,
    178     ) -> PyTorchPredictor:
--> 179         return self.train_model(
    180             training_data,
    181             validation_data,

~/pytorch-ts/pts/model/estimator.py in train_model(self, training_data, validation_data, num_workers, prefetch_factor, shuffle_buffer_length, cache_data, **kwargs)
    149             )
    150 
--> 151         self.trainer(
    152             net=trained_net,
    153             train_iter=training_data_loader,

~/pytorch-ts/pts/trainer.py in __call__(self, net, train_iter, validation_iter)
     65 
     66                     inputs = [v.to(self.device) for v in data_entry.values()]
---> 67                     output = net(*inputs)
     68 
     69                     if isinstance(output, (list, tuple)):

~/miniconda3/envs/rpp/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/pytorch-ts/pts/model/transformer_tempflow/transformer_tempflow_network.py in forward(self, target_dimension_indicator, past_time_feat, past_target_cdf, past_observed_values, past_is_pad, future_time_feat, future_target_cdf, future_observed_values)
    352         print(self.transformer.encoder)
    353         enc_out = self.transformer.encoder(
--> 354             self.encoder_input(enc_inputs).permute(1, 0, 2)
    355         )
    356 

~/miniconda3/envs/rpp/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/envs/rpp/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
     92 
     93     def forward(self, input: Tensor) -> Tensor:
---> 94         return F.linear(input, self.weight, self.bias)
     95 
     96     def extra_repr(self) -> str:

~/miniconda3/envs/rpp/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)
   1754 
   1755 

RuntimeError: mat1 dim 1 must match mat2 dim 0
kashif commented 2 years ago

Ah damn, I remember that the new version of PyTorch doesn't print the actual dims... let me check my notebook.

kashif commented 2 years ago

can you kindly try:

estimator = TempFlowEstimator(
    target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),
    prediction_length=dataset.metadata.prediction_length,
    cell_type='GRU',
    input_size=3856,
    num_cells=128,
    freq=dataset.metadata.freq,
    scaling=True,
    n_blocks=4,
    dropout_rate=0.3,
    trainer=Trainer(device=device,
                    epochs=21,
                    learning_rate=1e-3,
                    num_batches_per_epoch=100,
                    batch_size=32,
                    num_workers=8)
)
g-benton commented 2 years ago

That works - thanks. Is the best solution to this problem really to just print out the inputs to the forward pass of the model and then use their size as the input dimension?
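
One way to recover the sizes that the error message no longer prints is to hook the model's nn.Linear layers and record any mismatch before the failing matmul runs. This is only a sketch on a toy model: `watch_linear_inputs` is a hypothetical helper, not part of pytorch-ts, and it only covers nn.Linear layers (such as the encoder_input projection in the traceback above), not RNN cells.

```python
import torch
from torch import nn


def watch_linear_inputs(model, mismatches):
    """Record (layer name, expected size, actual size) for mismatched nn.Linear inputs.

    Hypothetical helper for illustration; returns the hook handles so they
    can be removed afterwards with handle.remove().
    """
    handles = []

    def make_hook(name):
        def hook(module, inputs):
            # The pre-hook runs before forward(), so it fires even when
            # the subsequent matmul raises a RuntimeError.
            actual = inputs[0].shape[-1]
            if actual != module.in_features:
                mismatches.append((name, module.in_features, actual))
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_pre_hook(make_hook(name)))
    return handles


# Toy model standing in for the real network: the layer expects 8 features.
toy = nn.Sequential(nn.Linear(8, 4))
found = []
hooks = watch_linear_inputs(toy, found)
try:
    toy(torch.zeros(2, 5))  # deliberately wrong: 5 features instead of 8
except RuntimeError:
    pass
for name, expected, actual in found:
    print(f"layer {name!r}: expected {expected} input features, got {actual}")
for h in hooks:
    h.remove()
```

In the toy case the "got" value is the size the data actually has, which is the number the layer would need to be built for.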

kashif commented 2 years ago

The ideal solution is to calculate input_size from the covariate feature sizes, the multivariate dimension, the number of lag features, etc. I believe the new GluonTS estimator API does that... I haven't gotten around to updating this yet.

hanlaoshi commented 1 year ago

Right, I have reproduced the code and found that input_size = target_dim + target_dim*len(lags_seq) + len(time_feat). For example, in the solar dataset, target_dim=137, len(lags_seq)=3, and len(time_feat)=4, so input_size = 137 + 137*3 + 4 = 552. The same formula applies to traffic_nips. Good luck!
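
The formula in the comment above can be written as a small helper. This is a sketch, not pytorch-ts API: `tempflow_input_size` is a hypothetical name. The traffic_nips call assumes target_dim=963 and the same lag and time-feature counts as solar; neither is stated in this thread, so check them against your dataset metadata and estimator settings.

```python
def tempflow_input_size(target_dim: int, num_lags: int, num_time_feat: int) -> int:
    """input_size = target_dim + target_dim * len(lags_seq) + len(time_feat).

    Hypothetical helper implementing the formula reported in this thread.
    """
    return target_dim + target_dim * num_lags + num_time_feat


# Solar values quoted in the thread: 137 + 137*3 + 4
print(tempflow_input_size(target_dim=137, num_lags=3, num_time_feat=4))  # 552

# traffic_nips, ASSUMING target_dim=963 with 3 lags and 4 time features
print(tempflow_input_size(target_dim=963, num_lags=3, num_time_feat=4))  # 3856
```

Under those assumptions the second call gives 3856, the input_size suggested earlier in the thread for TempFlowEstimator on traffic_nips.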