c3-varun opened 1 year ago
I am using a Windows machine, and saving the model `tft` as a pickle file worked for me. I was able to load the model and then run the `.predict` method. The required imports must be present in the environment where you load the model.
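I saved the model after retrieving the best model, i.e.:

```python
from pytorch_forecasting import TemporalFusionTransformer

best_model_path = trainer.checkpoint_callback.best_model_path  # trainer: the fitted Lightning Trainer
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
```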
Once we have the best model, we can save both the network and its weights using `pickle`:

```python
import pickle

# Serialize the whole model object (architecture + weights)
with open("tft.pkl", "wb") as f:
    pickle.dump(best_tft, f)
```
To load the model:

```python
import pickle

with open("tft.pkl", "rb") as f:
    model = pickle.load(f)
```
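Pickling stores a reference to the model's class, so the same class definitions must be importable wherever the file is loaded. A common alternative, mentioned later in this thread, is to save only the `state_dict` (the weight tensors) and rebuild the network from its constructor arguments before loading the weights. The minimal sketch below illustrates that pattern with a small hypothetical `TinyNet` module standing in for the TFT; it is not the pytorch-forecasting API, and an in-memory buffer stands in for a file path.

```python
import io

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a real model such as the TFT."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.layer = nn.Linear(4, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# Save: keep the constructor arguments next to the weights,
# because the state_dict alone does not describe the architecture.
model = TinyNet(hidden_size=8)
buffer = io.BytesIO()  # stands in for a path such as "weights.pth"
torch.save({"hidden_size": model.hidden_size, "state_dict": model.state_dict()}, buffer)

# Load: rebuild the architecture first, then copy the weights into it.
buffer.seek(0)
checkpoint = torch.load(buffer)
restored = TinyNet(hidden_size=checkpoint["hidden_size"])
restored.load_state_dict(checkpoint["state_dict"])
restored.eval()

# Both models should now produce identical outputs.
x = torch.randn(2, 4)
same = torch.equal(model(x), restored(x))
```

Because no class object is pickled, this avoids the locally-defined-class problem, at the cost of having to recreate the model yourself before loading.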
I'm currently using pytorch-forecasting 1.0.0 and have the same problem when trying to pickle a model such as `TemporalFusionTransformer`. The problem seems to be that one of its superclasses, `TupleOutputMixIn`, defines a local (nested) object inside its `to_network_output()` method, and `pickle` cannot serialize locally defined objects.
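The failure mode described above can be reproduced without pytorch-forecasting. The sketch below is a hypothetical mixin (`OutputMixin`, `Model`, and `PicklableModel` are illustrative names, not library classes) that defines a class inside a method and caches it on the instance, which is enough to make the whole instance unpicklable. The `__getstate__` override is a general Python workaround that drops the cached class before pickling; it is an assumption for illustration, not the library's own fix.

```python
import pickle
from collections import namedtuple


class OutputMixin:
    """Hypothetical mixin: builds a class inside a method and caches it."""

    def to_output(self, **results):
        if not hasattr(self, "_output_class"):
            # Local class: its qualified name contains '<locals>', so pickle
            # cannot look it up again by module + name.
            class Output(namedtuple("Output", results)):
                pass

            self._output_class = Output
        return self._output_class(**results)


class Model(OutputMixin):
    def predict(self):
        return self.to_output(prediction=1.0)


class PicklableModel(Model):
    def __getstate__(self):
        # Drop the cached local class; to_output() rebuilds it lazily.
        state = self.__dict__.copy()
        state.pop("_output_class", None)
        return state


def try_pickle(obj) -> bool:
    """Return True if obj can be pickled, False on the local-object error."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


broken = Model()
broken.predict()  # caches the local class on the instance
fixed = PicklableModel()
fixed.predict()

print(try_pickle(broken))  # False
print(try_pickle(fixed))   # True
```

Depending on the Python version, the error raised for `broken` is an `AttributeError` ("Can't pickle local object ...") or a `pickle.PicklingError`, which is why both are caught.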
Other libraries: PrettyTable-3.6.0 autopage-0.5.1 cliff-4.2.0 cmaes-0.9.1 cmd2-2.4.3 colorlog-6.7.0 optuna-2.10.1 pandas-1.5.3 pbr-5.11.1 pyperclip-1.8.2 pytorch-lightning-1.9.4 scikit-learn-1.1.3 scipy-1.10.1 stevedore-5.0.0
Expected behavior
I executed `torch.save(tft, 'Baselining-48-720-1-720.pth')` to save my model along with its weights, and I expected the file to be saved. I was able to save the weights alone using `torch.save(tft.state_dict(), 'Baselining-48-720-1-720.pth')`, but that doesn't save the network.
Actual behavior
I got the following error:
AttributeError Traceback (most recent call last)