luningsun opened 1 year ago
I trained a TempFlow model and saved it with `serialize`. However, when I try to deserialize it, I get the following error:
```
  File "gluon07_test.py", line 232, in <module>
    predictor_deserialized = PyTorchPredictor.deserialize(Path("./save_model_gluon07_test/"))
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/torch/model/predictor.py", line 141, in deserialize
    prediction_net = load_json(fp.read())
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/core/serde/_json.py", line 77, in load_json
    return decode(json.loads(s))
  File "/home/luningsun/anaconda3/envs/gluon/lib/python3.7/site-packages/gluonts/core/serde/_base.py", line 295, in decode
    return cls(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'num_parallel_samples'
```
For reference, this is the code I used to save and load the model:
```python
# save model
predictor = estimator.train(dataset_train)

from pathlib import Path
predictor.serialize(Path("./save_model_gluon07_test/"))

# load it back
from gluonts.torch.model.predictor import PyTorchPredictor
predictor_deserialized = PyTorchPredictor.deserialize(Path("./save_model_gluon07_test/"))
```
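A hedged illustration of one way this kind of error can arise. This is a guess and has not been checked against the TempFlow or GluonTS source. The last traceback frame (`return cls(*args, **kwargs)` in `serde/_base.py`) shows that decoding rebuilds the prediction network by calling its class with constructor arguments stored at save time. Suppose the prediction network's class adds a required argument such as `num_parallel_samples` but only its parent class's arguments get recorded. In that case the rebuild call omits the argument and fails exactly like this. The sketch below is self-contained and does not use GluonTS. `record_init_args`, `encode`, `decode`, `TrainingNet`, `PredictionNet` and `FixedPredictionNet` are all hypothetical names chosen for illustration.

```python
import inspect
from typing import Any, Dict


def record_init_args(init):
    """Hypothetical stand-in for a serde decorator: remembers the arguments
    passed to the decorated __init__ so the object can be rebuilt later."""

    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        # Keep only the first (outermost) recording, so a decorated subclass
        # is not overwritten when it calls its parent's __init__.
        if not hasattr(self, "_init_args"):
            self._init_args = {k: v for k, v in bound.arguments.items() if k != "self"}
        init(self, *args, **kwargs)

    return wrapper


def encode(obj) -> Dict[str, Any]:
    # The class comes from the object itself, the arguments from the record.
    return {"class": type(obj), "kwargs": dict(obj._init_args)}


def decode(data: Dict[str, Any]):
    # Rebuild by calling the class with the recorded arguments.
    return data["class"](**data["kwargs"])


class TrainingNet:
    @record_init_args
    def __init__(self, num_layers: int, hidden_size: int) -> None:
        self.num_layers = num_layers
        self.hidden_size = hidden_size


class PredictionNet(TrainingNet):
    # Not decorated: only the parent's arguments end up recorded.
    def __init__(self, num_parallel_samples: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_parallel_samples = num_parallel_samples


class FixedPredictionNet(TrainingNet):
    # Decorated with explicit parameters: every constructor argument is recorded.
    @record_init_args
    def __init__(self, num_parallel_samples: int, num_layers: int, hidden_size: int) -> None:
        super().__init__(num_layers=num_layers, hidden_size=hidden_size)
        self.num_parallel_samples = num_parallel_samples


net = PredictionNet(num_parallel_samples=100, num_layers=2, hidden_size=16)
try:
    decode(encode(net))
except TypeError as err:
    print(err)  # the message names the missing 'num_parallel_samples' argument

fixed = decode(encode(FixedPredictionNet(num_parallel_samples=100, num_layers=2, hidden_size=16)))
print(fixed.num_parallel_samples)  # 100
```

In this sketch, the fix is to make sure the subclass's own constructor arguments are recorded along with its parent's. Whether the TempFlow prediction network needs the same change is an assumption, not something confirmed here.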