zalandoresearch / pytorch-ts

PyTorch based Probabilistic Time Series forecasting framework based on GluonTS backend
MIT License
1.24k stars, 191 forks

Error loading pytorchts model #128

Open luningsun opened 1 year ago

luningsun commented 1 year ago

I trained a TempFlow model and saved it using serialize. However, when I try to deserialize it, I get the following error:

  File "gluon07_test.py", line 232, in <module>
    predictor_deserialized = PyTorchPredictor.deserialize(Path("./save_model_gluon07_test/"))
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/torch/model/predictor.py", line 141, in deserialize
    prediction_net = load_json(fp.read())
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/core/serde/_json.py", line 77, in load_json
    return decode(json.loads(s))
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/core/serde/_base.py", line 295, in decode
    return cls(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'num_parallel_samples'
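The last frame shows the decoder rebuilding an object via cls(*args, **kwargs) and failing because num_parallel_samples was not among the saved arguments. A plausible but unconfirmed reading is that the constructor arguments recorded at save time belong to a parent class, while the prediction network's own required argument is never stored. The sketch below reproduces that failure mode in plain Python; record_init_args, encode, decode, TrainingNet and PredictionNet are hypothetical stand-ins, not the GluonTS serde API.

```python
import functools
import inspect
import json


def record_init_args(init):
    """Hypothetical decorator: remember the arguments of the outermost decorated __init__."""
    sig = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Only the first decorated __init__ to run records anything, so a
        # decorated parent called via super() does not overwrite a child's args.
        if not hasattr(self, "_init_args"):
            bound = sig.bind(self, *args, **kwargs)
            bound.apply_defaults()
            recorded = {}
            for name, value in bound.arguments.items():
                if name == "self":
                    continue
                if sig.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
                    recorded.update(value)  # flatten **kwargs into the record
                else:
                    recorded[name] = value
            self._init_args = recorded
        init(self, *args, **kwargs)

    return wrapper


class TrainingNet:
    @record_init_args
    def __init__(self, input_size, num_layers=2):
        self.input_size = input_size
        self.num_layers = num_layers


class PredictionNet(TrainingNet):
    # Not decorated: num_parallel_samples is consumed here and never recorded;
    # only the arguments forwarded to TrainingNet.__init__ are.
    def __init__(self, num_parallel_samples, **kwargs):
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples


REGISTRY = {cls.__name__: cls for cls in (TrainingNet, PredictionNet)}


def encode(obj):
    # Save the class name plus whatever constructor arguments were recorded.
    return json.dumps({"class": type(obj).__name__, "kwargs": obj._init_args})


def decode(text):
    # Rebuild the object by calling its class with the saved arguments.
    data = json.loads(text)
    return REGISTRY[data["class"]](**data["kwargs"])


net = PredictionNet(num_parallel_samples=100, input_size=4)
payload = encode(net)
print(payload)  # {"class": "PredictionNet", "kwargs": {"input_size": 4, "num_layers": 2}}
try:
    decode(payload)
except TypeError as err:
    print("decode failed:", err)
```

In this toy, decorating PredictionNet.__init__ with record_init_args as well makes the round trip succeed, because the outermost constructor's arguments (including num_parallel_samples) are then the ones saved. Whether the actual pytorch-ts prediction network lacks such argument recording is an assumption to check against its source; the traceback alone does not prove it.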

FYI, here is the code I used to save and load the model:

Save the model:

from pathlib import Path

predictor = estimator.train(dataset_train)
predictor.serialize(Path("./save_model_gluon07_test/"))

Load it back:

from pathlib import Path
from gluonts.torch.model.predictor import PyTorchPredictor

predictor_deserialized = PyTorchPredictor.deserialize(Path("./save_model_gluon07_test/"))
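One way to test the "missing constructor argument" reading before touching any library code is to compare the arguments that were saved against the constructor's signature. The helper below is a hypothetical sketch: it assumes you have already extracted the saved keyword arguments into a dict (how they are laid out on disk is not covered here), and PredictionNet is a stand-in class whose signature merely mirrors the error message.

```python
import inspect


def missing_required_args(cls, saved_kwargs):
    """Return the required __init__ parameters of cls that saved_kwargs does not supply."""
    missing = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        # Skip self and catch-all *args / **kwargs parameters.
        if name == "self" or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        # A parameter with no default must be provided by the saved arguments.
        if param.default is inspect.Parameter.empty and name not in saved_kwargs:
            missing.append(name)
    return missing


class PredictionNet:
    # Hypothetical signature; the real network's arguments may differ.
    def __init__(self, num_parallel_samples, input_size, num_layers=2):
        self.num_parallel_samples = num_parallel_samples
        self.input_size = input_size
        self.num_layers = num_layers


saved = {"input_size": 4, "num_layers": 2}
print(missing_required_args(PredictionNet, saved))  # ['num_parallel_samples']
```

An empty list means every required argument was saved, so the failure would lie elsewhere.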