zalandoresearch / pytorch-ts

PyTorch based Probabilistic Time Series forecasting framework based on GluonTS backend
MIT License

Seconds frequency not supported #90

Open babaksit opened 2 years ago

babaksit commented 2 years ago

Hello,

How can I add support for a seconds frequency?

When I run the following code:

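  import torch
  from pts import Trainer
  from pts.model.deepar import DeepAREstimator

  # training_data is the user's training dataset, defined earlier
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  estimator = DeepAREstimator(freq="1s",
                              prediction_length=12,
                              input_size=19,
                              trainer=Trainer(epochs=10,
                                              device=device))
  predictor = estimator.train(training_data=training_data)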

I got the following error:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-10-7c9b3c3e6ba0> in <module>
      1 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      2 
----> 3 estimator = DeepAREstimator(freq="1s",
      4                             prediction_length=12,
      5                             input_size=19,

~/.conda/envs/thesis/lib/python3.8/site-packages/gluonts/core/component.py in init_wrapper(*args, **kwargs)
    343                 self.__class__.__repr__ = validated_repr
    344 
--> 345             return init(self, **all_args)
    346 
    347         # attach the Pydantic model as the attribute of the initializer wrapper

~/.conda/envs/thesis/lib/python3.8/site-packages/pts/model/deepar/deepar_estimator.py in __init__(self, freq, prediction_length, input_size, trainer, context_length, num_layers, num_cells, cell_type, dropout_rate, use_feat_dynamic_real, use_feat_dynamic_cat, use_feat_static_cat, use_feat_static_real, cardinality, embedding_dimension, distr_output, scaling, lags_seq, time_features, num_parallel_samples, dtype)
     94             lags_seq
     95             if lags_seq is not None
---> 96             else get_lags_for_frequency(freq_str=freq, lag_ub=self.context_length)
     97         )
     98         self.time_features = (

~/.conda/envs/thesis/lib/python3.8/site-packages/gluonts/time_feature/lag.py in get_lags_for_frequency(freq_str, lag_ub, num_lags)
    121         )
    122     else:
--> 123         raise Exception("invalid frequency")
    124 
    125     # flatten lags list and filter

Exception: invalid frequency
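The traceback shows that the exception is raised by `get_lags_for_frequency` in `gluonts/time_feature/lag.py`, which chooses lag indices from the frequency's base unit and falls through to `raise Exception("invalid frequency")` for units it does not handle; in this version, seconds is one of them. The sketch below is a simplified, hypothetical stand-in for that dispatch, not the GluonTS code: `parse_freq`, `make_lags`, `lags_for_frequency` and `CYCLE_LENGTH` are illustrative names, only a few units are covered, and the seconds branch is just one possible choice (lags clustered around one, two and three minutes' worth of steps).

```python
import re
from typing import List, Tuple

# Hypothetical cycle length (in base units) used to place seasonal lags:
# minutes repeat every hour, hours every day, days every week,
# and the new seconds branch repeats every minute.
CYCLE_LENGTH = {"s": 60, "S": 60, "min": 60, "T": 60, "H": 24, "h": 24, "D": 7}


def parse_freq(freq_str: str) -> Tuple[int, str]:
    """Split a pandas-style frequency string such as "1s" or "15min" into (multiple, unit)."""
    match = re.fullmatch(r"(\d*)([A-Za-z]+)", freq_str.strip())
    if match is None:
        raise ValueError(f"cannot parse frequency {freq_str!r}")
    multiple = int(match.group(1)) if match.group(1) else 1
    return multiple, match.group(2)


def make_lags(middle: int, delta: int) -> List[int]:
    """All lags in the window [middle - delta, middle + delta]."""
    return list(range(middle - delta, middle + delta + 1))


def lags_for_frequency(freq_str: str, lag_ub: int = 1200, num_cycles: int = 3) -> List[int]:
    """Return sorted lag indices for a frequency; unknown units raise, as in the traceback."""
    multiple, unit = parse_freq(freq_str)
    if unit not in CYCLE_LENGTH:
        # Plays the role of `raise Exception("invalid frequency")` in gluonts' lag.py.
        raise ValueError("invalid frequency")
    cycle = CYCLE_LENGTH[unit]
    # Short-range lags 1..7 plus windows around 1, 2, ..., num_cycles full cycles.
    candidates = list(range(1, 8))
    for k in range(1, num_cycles + 1):
        candidates.extend(make_lags(k * cycle // multiple, 2))
    # Keep positive lags no larger than lag_ub, without duplicates.
    return sorted({lag for lag in candidates if 0 < lag <= lag_ub})
```

With this sketch, `lags_for_frequency("1s")` yields lags 1 to 7 plus the windows 58 to 62, 118 to 122 and 178 to 182. The estimator passes `lag_ub=self.context_length`, so a short context window drops the seasonal lags entirely. Separately, traceback lines 94 to 96 show that `get_lags_for_frequency` is only called when `lags_seq` is None, so passing an explicit `lags_seq` to `DeepAREstimator` may get past this particular error, although the time-feature lookup for "1s" could still fail afterwards.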

kashif commented 2 years ago

@babaksit Yes, you can add support for it by opening a PR that changes this file: https://github.com/awslabs/gluon-ts/blob/master/src/gluonts/time_feature/_base.py#L170
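The linked `_base.py` appears to be where GluonTS defines its time-feature classes and maps a frequency string to the features it uses. Note that in the traceback the estimator computes lags (lines 94 to 96) before time features (line 98), so the failure happens in `lag.py` first and a PR would likely need to touch both files. The sketch below shows, in plain Python rather than the GluonTS `TimeFeature`/pandas API, what a seconds entry could look like. `second_of_minute`, `FEATURES_BY_UNIT`, `time_features_for` and `feature_rows` are hypothetical names; the [-0.5, 0.5] scaling mirrors the one GluonTS uses for its minute and hour features.

```python
from datetime import datetime, timedelta
from typing import Callable, Dict, List

TimeFeature = Callable[[datetime], float]


def second_of_minute(ts: datetime) -> float:
    """Second of minute scaled to [-0.5, 0.5] (hypothetical new feature)."""
    return ts.second / 59.0 - 0.5


def minute_of_hour(ts: datetime) -> float:
    """Minute of hour scaled to [-0.5, 0.5]."""
    return ts.minute / 59.0 - 0.5


def hour_of_day(ts: datetime) -> float:
    """Hour of day scaled to [-0.5, 0.5]."""
    return ts.hour / 23.0 - 0.5


# Hypothetical unit -> features mapping; "s" is the new entry and also
# reuses the coarser minute and hour features.
FEATURES_BY_UNIT: Dict[str, List[TimeFeature]] = {
    "min": [minute_of_hour, hour_of_day],
    "s": [second_of_minute, minute_of_hour, hour_of_day],
}


def time_features_for(unit: str) -> List[TimeFeature]:
    """Look up the feature list for a base unit, failing loudly for unknown units."""
    if unit not in FEATURES_BY_UNIT:
        raise ValueError(f"unsupported frequency unit {unit!r}")
    return FEATURES_BY_UNIT[unit]


def feature_rows(start: datetime, periods: int, step_seconds: int, unit: str) -> List[List[float]]:
    """One row per feature, one column per timestamp (a features x time layout)."""
    stamps = [start + timedelta(seconds=i * step_seconds) for i in range(periods)]
    return [[feature(ts) for ts in stamps] for feature in time_features_for(unit)]
```

For example, `feature_rows(datetime(2021, 1, 1, 0, 0, 58), 3, 1, "s")` covers the timestamps 00:00:58, 00:00:59 and 00:01:00, so the second-of-minute row runs from just below 0.5 up to 0.5 and then wraps back to -0.5 as the minute-of-hour row ticks up.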

babaksit commented 2 years ago

@kashif Thank you for the prompt answer. I'll give it a try; I hope it works in the end :)