jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

How to save a Pytorch-Forecasting model after training it #1286

Open c3-varun opened 1 year ago

c3-varun commented 1 year ago

Other libraries: PrettyTable-3.6.0 autopage-0.5.1 cliff-4.2.0 cmaes-0.9.1 cmd2-2.4.3 colorlog-6.7.0 optuna-2.10.1 pandas-1.5.3 pbr-5.11.1 pyperclip-1.8.2 pytorch-lightning-1.9.4 scikit-learn-1.1.3 scipy-1.10.1 stevedore-5.0.0

Expected behavior

I ran torch.save(tft, 'Baselining-48-720-1-720.pth') to save my model together with its weights, and I expected the file to be written.

I was able to save weights using torch.save(tft.state_dict(), 'Baselining-48-720-1-720.pth'), but that doesn't save the network.

Actual behavior

I got the following error:

AttributeError Traceback (most recent call last)

in
----> 1 torch.save(tft, 'Baselining-48-720-1-720.pth')

~/.local/lib/python3.8/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    421     if _use_new_zipfile_serialization:
    422         with _open_zipfile_writer(f) as opened_zipfile:
--> 423             _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    424             return
    425     else:

~/.local/lib/python3.8/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    633     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    634     pickler.persistent_id = persistent_id
--> 635     pickler.dump(obj)
    636     data_value = data_buf.getvalue()
    637     zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'TupleOutputMixIn.to_network_output.<locals>.Output'

### Code to reproduce the problem

My model is a TFT and I trained it following this tutorial (https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html):

```
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=len(quantiles),
    loss=QuantileLoss(quantiles=quantiles),
    log_interval=10,  # uncomment for learning rate finder and otherwise, e.g. to 10 for logging every 10 batches
    reduce_on_plateau_patience=4,
)
```

The following line failed:

```
torch.save(tft, 'Baselining-48-720-1-720.pth')
```
GhoulMac commented 8 months ago

I am using a Windows machine, and saving the model tft as a pickle file worked for me. I was able to load the model and then run the .predict method. The required imports should be present in the environment where the model is loaded.

I saved the model after obtaining the best model from the checkpoint, i.e.:

best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

Once we have the best model, we can save the network and the weights using pickle:

import pickle

with open("tft.pkl",'wb') as f:
    pickle.dump(best_tft, f)

To load the model

import pickle

with open("tft.pkl",'rb') as f:
    model=pickle.load(f)
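
As a self-contained illustration of the save/load round trip described above, here is a minimal sketch that uses only the standard library. `StandInModel` is a hypothetical placeholder, not a pytorch-forecasting class; it only stands in for whatever object you want to persist. Note that `pickle.dump` needs both the object and the open file handle.

```python
import os
import pickle
import tempfile
from dataclasses import dataclass, field


@dataclass
class StandInModel:
    """Hypothetical placeholder for a trained model object."""

    hidden_size: int = 16
    weights: list = field(default_factory=lambda: [0.1, 0.2, 0.3])


model = StandInModel()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pkl")

    # Save: pass the object AND the binary file handle to pickle.dump.
    with open(path, "wb") as f:
        pickle.dump(model, f)

    # Load: the class definition must be importable in the loading process,
    # because pickle stores a reference to the class, not its source code.
    with open(path, "rb") as f:
        restored = pickle.load(f)
```

The loaded object is an independent copy with the same field values. With a real model, the same constraint applies: the packages that define its classes have to be installed in the environment that loads the file.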
ivanightingale commented 5 months ago

I'm currently using pytorch-forecasting 1.0.0 and have the same problem when trying to pickle a model like TemporalFusionTransformer.

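The problem seems to be that one of its superclasses, TupleOutputMixIn, defines a local object (named Output in the error message) inside its to_network_output() method, and pickle cannot serialize objects that are defined inside a function.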
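
For illustration, here is a minimal, standard-library-only sketch of that failure mode. `TupleOutputSketch` and `can_pickle` are hypothetical names, not pytorch-forecasting code; the class only imitates the pattern the error message suggests, assuming a class is created inside a method and cached on the instance. Under that assumption, a fresh instance pickles fine, but pickling fails once the method has run:

```python
import pickle
from collections import namedtuple


class TupleOutputSketch:
    """Hypothetical stand-in imitating the pattern named in the error message."""

    def to_network_output(self, **results):
        # Build the output class lazily inside the method and cache it on self.
        if not hasattr(self, "_output_class"):
            OutputTuple = namedtuple("output", list(results))

            class Output(OutputTuple):  # local class: only reachable via this method
                pass

            self._output_class = Output
        return self._output_class(**results)


def can_pickle(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    # Depending on the Python version, a local object raises one or the other.
    except (AttributeError, pickle.PicklingError):
        return False


model = TupleOutputSketch()
fresh_ok = can_pickle(model)  # nothing cached yet, so no local class is involved

out = model.to_network_output(prediction=[1.0, 2.0])
used_ok = can_pickle(model)  # the instance now references the local class
```

If the library behaves like this sketch, that may explain why pickling a model freshly loaded from a checkpoint can succeed while pickling a model that has already produced output fails. This has not been checked against the library source.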