Closed StatMixedML closed 4 years ago
I am trying to use DeepVAREstimator from the issue-3 branch, but training fails with NameError: name 'SetField' is not defined.
from pts.transform import SetField

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = Trainer(device = device, epochs = 10)

estimator = DeepVAREstimator(input_size = 401,
                             freq = "1M",
                             prediction_length = pred_h,
                             context_length = pred_h*2,
                             target_dim = target_dim,
                             use_feat_static_cat = True,
                             cardinality = card_static,
                             trainer = trainer)

predictor = estimator.train(training_data = train_ds)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-27-375c015eb18b> in <module>
     20 # time_features = feat_dynamic_real_train,
     21 trainer = trainer)
---> 22 predictor = estimator.train(training_data = train_ds)
     23 predictor.__dict__["prediction_net"]

~/miniconda3/envs/pytorchts/lib/python3.7/site-packages/pts/model/estimator.py in train(self, training_data)
    132
    133     def train(self, training_data: Dataset) -> Predictor:
--> 134         return self.train_model(training_data).predictor

~/miniconda3/envs/pytorchts/lib/python3.7/site-packages/pts/model/estimator.py in train_model(self, training_data)
     98
     99     def train_model(self, training_data: Dataset) -> TrainOutput:
--> 100         transformation = self.create_transformation()
    101         transformation.estimate(iter(training_data))
    102

~/miniconda3/envs/pytorchts/lib/python3.7/site-packages/pts/model/deepvar/deepvar_estimator.py in create_transformation(self)
    154                 else []
    155             )
--> 156             + [
    157                 AsNumpyArray(
    158                     field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=np.long,

NameError: name 'SetField' is not defined
Potential solution: add SetField to the existing pts.transform import in deepvar_estimator.py:
from pts.transform import (
    ...
    SetField,
)
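For anyone wondering why the `from pts.transform import SetField` line in the notebook did not help: Python resolves a name used inside a function against the globals of the module where that function was defined, not against the caller's namespace. The sketch below reproduces that behaviour without pts installed. The module name `fake_deepvar_estimator` and the toy `SetField` function are hypothetical stand-ins, not the real pts code.

```python
import types

# Hypothetical stand-in for deepvar_estimator.py: it uses SetField
# without importing it, like the module shown in the traceback.
MODULE_SOURCE = '''
def create_transformation():
    return SetField("feat_static_cat", [0.0])
'''


def SetField(output_field, value):
    """Toy stand-in for pts.transform.SetField (not the real class)."""
    return {output_field: value}


# Build the fake module; its globals do NOT contain SetField.
estimator_module = types.ModuleType("fake_deepvar_estimator")
exec(MODULE_SOURCE, estimator_module.__dict__)

# SetField exists in *this* namespace, yet the call still fails,
# because the lookup happens in estimator_module's globals.
try:
    estimator_module.create_transformation()
    error_before_fix = None
except NameError as err:
    error_before_fix = str(err)

# Putting the name into the module's own namespace fixes the lookup.
# This is what adding the import inside the module achieves.
estimator_module.SetField = SetField
result_after_fix = estimator_module.create_transformation()

print(error_before_fix)
print(result_after_fix)
```

The same idea suggests a stopgap that avoids editing the installed file: assuming the module is importable as `pts.model.deepvar.deepvar_estimator` (inferred from the path in the traceback, not verified), assigning `SetField` onto that module object before calling `estimator.train(...)` should have the same effect as the import. The proper fix is still the import in deepvar_estimator.py.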
@StatMixedML cool will fix it in a few mins. Thanks!
@kashif Any updates on this?
fixed by 8e9f31eae120d42f7e09ef738542a4d7e73f60e8