zalandoresearch / pytorch-ts

PyTorch-based probabilistic time series forecasting framework built on the GluonTS backend
MIT License
1.24k stars, 191 forks

TempFlowEstimator deserialize error #30

Closed: goomhow closed this issue 3 years ago

goomhow commented 3 years ago

I trained a TempFlow model and then saved (serialized) it to a folder, but when I load (deserialize) the model, I get this error:

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in <module>
    p = Predictor.deserialize(get_model_path('tempflow'))
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\model\predictor.py", line 82, in deserialize
    return tpe.deserialize(path, device)
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\model\predictor.py", line 172, in deserialize
    transformation = load_json(fp.read())
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 173, in load_json
    return decode(json.loads(s))
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 354, in decode
    kwargs = decode(r["kwargs"]) if "kwargs" in r else {}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 362, in decode
    return {k: decode(v) for k, v in r.items()}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 362, in <dictcomp>
    return {k: decode(v) for k, v in r.items()}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 368, in decode
    return [decode(y) for y in r]
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 368, in <listcomp>
    return [decode(y) for y in r]
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 354, in decode
    kwargs = decode(r["kwargs"]) if "kwargs" in r else {}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 362, in decode
    return {k: decode(v) for k, v in r.items()}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 362, in <dictcomp>
    return {k: decode(v) for k, v in r.items()}
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 368, in decode
    return [decode(y) for y in r]
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 368, in <listcomp>
    return [decode(y) for y in r]
  File "C:\ProgramData\Anaconda3\lib\site-packages\pts\core\serde.py", line 355, in decode
    return cls(*args, **kwargs)
TypeError: __init__() got an unexpected keyword argument 'normalized'
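The last frames of this traceback show `decode` rebuilding each serialized object by calling `cls(*args, **kwargs)` with the arguments stored in the JSON file. One way such a `TypeError` arises is when the saved JSON records a constructor argument (here `normalized`) that the installed class no longer (or does not yet) accept. The sketch below reproduces that mechanism with a toy serializer; `ScalerV1`, `ScalerV2`, `dump`, and `load` are hypothetical stand-ins, not the pts serde API, and the thread does not confirm that this is exactly what happened here.

```python
import json


class ScalerV1:
    """Hypothetical transformation as defined by the version that saved the model."""

    def __init__(self, scale: float, normalized: bool = False):
        self.kwargs = {"scale": scale, "normalized": normalized}


class ScalerV2:
    """The same transformation in a hypothetical version without `normalized`."""

    def __init__(self, scale: float):
        self.kwargs = {"scale": scale}


def dump(obj) -> str:
    # Store the constructor arguments rather than the object's internal state.
    return json.dumps({"class": "Scaler", "kwargs": obj.kwargs})


def load(s: str, registry: dict):
    # Look up the class and call it with the stored kwargs,
    # analogous to the `return cls(*args, **kwargs)` frame in the traceback.
    r = json.loads(s)
    cls = registry[r["class"]]
    return cls(**r["kwargs"])


saved = dump(ScalerV1(scale=2.0, normalized=True))

# Loading with the same class definition round-trips cleanly.
restored = load(saved, {"Scaler": ScalerV1})

# Loading with a class whose __init__ lacks `normalized` fails.
message = ""
try:
    load(saved, {"Scaler": ScalerV2})
except TypeError as err:
    message = str(err)
print(message)
```

This is also why saving and loading with the same library version (or a release containing the fix) avoids the error: the stored kwargs and the constructor signature then agree.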

kashif commented 3 years ago

@goomhow it's fixed in master; I will double-check and make a new release with the fix. Thanks!

goomhow commented 3 years ago

I also found something strange: I'm trying to train a DeepAR model, and with the same data and the same code it runs fine with GluonTS, but with pytorch-ts this error appears:

Traceback (most recent call last):
  File "ts.py", line 273, in <module>
    train_deep_ar()
  File "ts.py", line 234, in train_deep_ar
    predictor: Predictor = estimator.train(training_data=training_data)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/model/estimator.py", line 148, in train
    return self.train_model(training_data).predictor
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/model/estimator.py", line 133, in train_model
    self.trainer(
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/trainer.py", line 52, in __call__
    output = net(*inputs)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/model/deepar/deepar_network.py", line 246, in forward
    distr = self.distribution(
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/model/deepar/deepar_network.py", line 221, in distribution
    rnn_outputs, _, scale, _ = self.unroll_encoder(
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/model/deepar/deepar_network.py", line 168, in unroll_encoder
    embedded_cat = self.embedder(feat_static_cat)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/modules/feature.py", line 30, in forward
    [
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/pts/modules/feature.py", line 31, in <listcomp>
    embed(cat_feature_slice.squeeze(-1))
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 124, in forward
    return F.embedding(
  File "/home/gonghao/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

And when I set feat_static_cat=False, everything works.

kashif commented 3 years ago

@goomhow this is unrelated to this issue... and is most probably because the embedding layer's configuration is off by one. Note that PyTorch's embedding layer takes indices from 0 to the number of categories - 1. Can you confirm that you are setting the cardinality properly?
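PyTorch's `nn.Embedding(num_embeddings, embedding_dim)` is a lookup table with `num_embeddings` rows, so the only valid indices are `0` to `num_embeddings - 1`. Passing an index equal to or larger than `num_embeddings` on CPU raises the same `IndexError: index out of range in self` seen in the traceback. A minimal sketch, assuming a single categorical feature with a hypothetical cardinality of 15:

```python
import torch
import torch.nn as nn

# Hypothetical setup: one categorical feature with 15 distinct values.
cardinality = 15
embed = nn.Embedding(num_embeddings=cardinality, embedding_dim=8)

# Valid indices run from 0 to cardinality - 1 (here 0..14).
valid_ids = torch.tensor([0, 7, 14])
out = embed(valid_ids)
print(out.shape)  # torch.Size([3, 8])

# An index equal to the cardinality is one past the end of the weight table.
raised = False
try:
    embed(torch.tensor([cardinality]))
except IndexError as err:
    raised = True
    print("IndexError:", err)
```

So the cardinality passed to the estimator must be strictly greater than the largest category id that appears in the data, not merely equal to the number of distinct categories you expect.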

goomhow commented 3 years ago

I set cardinality = [15, 301], which are the numbers of categoryLevel1 and categoryLevel2 values of our products, and the embedding_dimension defaults to [8, 50]. What bothers me is that the same code and data run with the GluonTS DeepAR model but fail with the pytorch-ts DeepAR model. Thanks for your patience.
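The default `[8, 50]` reported here is consistent with the rule GluonTS's `DeepAREstimator` applies when `embedding_dimension` is not given: half the cardinality, rounded up, capped at 50. Assuming pytorch-ts mirrors that rule (the matching numbers suggest so, but the thread does not state it), the arithmetic is:

```python
from typing import List


def default_embedding_dimension(cardinality: List[int]) -> List[int]:
    # Assumed heuristic: (cat + 1) // 2 per feature, capped at 50.
    return [min(50, (cat + 1) // 2) for cat in cardinality]


print(default_embedding_dimension([15, 301]))  # [8, 50]
```

Note that the embedding dimension only sets the width of each embedding vector; the range of accepted ids depends solely on the cardinality, so changing `embedding_dimension` would not make the `IndexError` go away.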

kashif commented 3 years ago

I see. If you increase each cardinality by 1, or by say 10, does the issue still occur?

goomhow commented 3 years ago

I now use only categoryLevel1 as the single feat_static_cat, so cardinality should be [15], but I found that it only runs when cardinality >= 8362... what a strange number.
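Because an embedding with `n` rows accepts ids `0` to `n - 1`, the smallest cardinality that trains without error is the largest id in the data plus one. A threshold of 8362 therefore points to a maximum id of 8361. A minimal sketch that computes this directly, assuming GluonTS-style dict entries with a `feat_static_cat` list per series (the `required_cardinality` helper and the entries are hypothetical):

```python
from typing import Dict, List, Sequence


def required_cardinality(
    entries: Sequence[Dict], field: str = "feat_static_cat"
) -> List[int]:
    """Smallest cardinality per categorical feature that keeps every id in range."""
    n_features = len(entries[0][field])
    # An embedding with n rows accepts ids 0..n-1, so each feature needs max(id) + 1.
    return [max(int(e[field][j]) for e in entries) + 1 for j in range(n_features)]


# Hypothetical entries: categoryLevel1 ids should be 0..14, but one series carries 8361.
entries = [
    {"target": [1.0, 2.0], "feat_static_cat": [0]},
    {"target": [3.0, 4.0], "feat_static_cat": [14]},
    {"target": [5.0, 6.0], "feat_static_cat": [8361]},
]
print(required_cardinality(entries))  # [8362]
```

If this value differs from the number of categories you expect, the data contains an id outside the intended range.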

kashif commented 3 years ago

@goomhow it seems that somewhere in your data there is a token with id 8361 or so... can you check?
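A minimal sketch of that check, plus one possible fix: re-encoding each categorical feature to contiguous ids `0..k-1` so the cardinality equals the number of distinct values. It assumes GluonTS-style dict entries with a `feat_static_cat` field; `find_out_of_range`, `encode_categories`, and the sample entries are hypothetical, and remapping is not the resolution reported in the thread (the stray id may instead be a data bug worth fixing at the source).

```python
from typing import Dict, List, Sequence, Tuple


def find_out_of_range(
    entries: Sequence[Dict], cardinality: Sequence[int], field: str = "feat_static_cat"
) -> List[Tuple[int, int, int]]:
    """Return (entry index, feature index, id) for every id outside 0..cardinality[j] - 1."""
    bad = []
    for i, entry in enumerate(entries):
        for j, value in enumerate(entry[field]):
            if not 0 <= int(value) < cardinality[j]:
                bad.append((i, j, int(value)))
    return bad


def encode_categories(
    entries: Sequence[Dict], field: str = "feat_static_cat"
) -> Tuple[List[Dict], List[int], List[Dict[int, int]]]:
    """Re-encode each categorical feature to contiguous ids 0..k-1.

    Returns the new entries, the resulting cardinality per feature, and the
    raw-code -> id mapping, which must be reused unchanged at prediction time.
    """
    n_features = len(entries[0][field])
    vocab: List[Dict[int, int]] = [{} for _ in range(n_features)]
    encoded = []
    for entry in entries:
        # setdefault hands out the next free id the first time a raw code is seen.
        ids = [vocab[j].setdefault(int(raw), len(vocab[j])) for j, raw in enumerate(entry[field])]
        encoded.append({**entry, field: ids})
    return encoded, [len(v) for v in vocab], vocab


# Hypothetical entries: the declared cardinality is [15], but one series carries id 8361.
entries = [
    {"target": [1.0, 2.0], "feat_static_cat": [0]},
    {"target": [3.0, 4.0], "feat_static_cat": [14]},
    {"target": [5.0, 6.0], "feat_static_cat": [8361]},
]
print(find_out_of_range(entries, cardinality=[15]))  # [(2, 0, 8361)]

encoded, new_cardinality, vocab = encode_categories(entries)
print([e["feat_static_cat"] for e in encoded])  # [[0], [1], [2]]
print(new_cardinality)  # [3]
```

Returning the mapping matters: ids assigned during training must be applied identically to the data passed to the predictor, otherwise the same product would look up a different embedding row.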

goomhow commented 3 years ago

Oh... you are right, thank you very much, you are so cool! And in Chinese: 你真牛逼 ("you're seriously awesome").