ZhangAllen98 opened 4 days ago
Hi @ZhangAllen98,
The checkpoints for Darts deep learning models are split into two files: the "Darts" model (which handles all the time-series-specific logic, `.pt`) and the weights themselves (`.ckpt`). Based on the error message you are getting, I am guessing that the weights (the "torch model") are not being loaded/exported properly.
Can you check whether you are able to call `predict()` after loading the model with `load_from_checkpoint()`?
I am not sure how to fix the way joblib pickles the model; maybe other users have some experience with that.
Internally we currently store two things when calling `TorchForecastingModel.save()` (or with checkpoints directly):
- the `TorchForecastingModel` object, which is the Darts framework wrapper around the neural network;
- the checkpoint of the neural network itself (the `.ckpt` file with its weights).

I assume joblib internally tries to pickle the object, which will call `TorchForecastingModel.__getstate__()` (saving) and `TorchForecastingModel.__setstate__()` (loading).
We configured it so that, when pickling the model, the `TorchForecastingModel.model` attribute (the actual neural network with all its parameters) is not stored, since otherwise we would have to save it twice.
When loading with `TorchForecastingModel.load*()`, the checkpoint file is retrieved and the network is attached to the wrapper.
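The failure can be reproduced without Darts. In this toy sketch, `ToyWrapper` and `ToyNetwork` are hypothetical stand-ins, not Darts classes: the wrapper drops its network in `__getstate__` and sets it to `None` in `__setstate__`, which mirrors the behaviour described above (whether Darts does it in exactly this way is an assumption). Calling `predict` on the unpickled object then raises the same `AttributeError` as in the issue.

```python
import pickle


class ToyNetwork:
    """Hypothetical stand-in for the neural network held in `.model`."""

    def set_predict_parameters(self, n):
        self.n = n


class ToyWrapper:
    """Hypothetical wrapper that leaves its network out when pickled."""

    def __init__(self):
        self.model = ToyNetwork()
        self.input_chunk_length = 12  # ordinary settings are pickled as usual

    def __getstate__(self):
        # copy the instance dict without the network, so it is not stored twice
        state = self.__dict__.copy()
        state.pop("model", None)
        return state

    def __setstate__(self, state):
        # restore the settings; the network is simply missing afterwards
        self.__dict__.update(state)
        self.model = None

    def predict(self, n):
        self.model.set_predict_parameters(n)
        return n


restored = pickle.loads(pickle.dumps(ToyWrapper()))
print(restored.input_chunk_length)  # 12
print(restored.model)  # None
try:
    restored.predict(6)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'set_predict_parameters'
```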
So I guess for this to work, you would have to store the two parts (wrapper and model) with joblib, then load the wrapper and the model separately, and finally attach the model to the wrapper. Something like this (I haven't tested it, though):
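```python
import joblib
joblib.dump(model, './wrapper_save')  # model: the fitted TSMixerModel; .model is not pickled
joblib.dump(model.model, './model_checkpoint_save')  # the neural network on its own
loaded_model = joblib.load('./wrapper_save')  # loaded_model.model is None at this point
loaded_checkpoint = joblib.load('./model_checkpoint_save')
loaded_model.model = loaded_checkpoint  # re-attach the network to the wrapper
```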
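The same two-part idea, sketched with the standard-library `pickle` module (joblib goes through the same `__getstate__`/`__setstate__` hooks). `ToyWrapper`, `ToyNetwork`, `save_parts` and `load_parts` are hypothetical names used only for illustration; with a real Darts model you would pass `model` and `model.model`, which, as noted, has not been tested.

```python
import os
import pickle
import tempfile


class ToyNetwork:
    """Hypothetical stand-in for the trained neural network."""

    def __init__(self, weight):
        self.weight = weight

    def set_predict_parameters(self, n):
        self.n = n


class ToyWrapper:
    """Hypothetical wrapper whose pickled state leaves out its network."""

    def __init__(self, weight):
        self.model = ToyNetwork(weight)

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("model", None)  # the network is not part of the wrapper's pickle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.model = None  # must be re-attached after loading

    def predict(self, n):
        self.model.set_predict_parameters(n)
        return [self.model.weight] * n


def save_parts(wrapper, directory):
    """Store the wrapper and its network as two separate files."""
    with open(os.path.join(directory, "wrapper_save"), "wb") as f:
        pickle.dump(wrapper, f)
    with open(os.path.join(directory, "model_checkpoint_save"), "wb") as f:
        pickle.dump(wrapper.model, f)


def load_parts(directory):
    """Load both files and put the network back under the wrapper."""
    with open(os.path.join(directory, "wrapper_save"), "rb") as f:
        loaded_model = pickle.load(f)  # loaded_model.model is None here
    with open(os.path.join(directory, "model_checkpoint_save"), "rb") as f:
        loaded_checkpoint = pickle.load(f)
    loaded_model.model = loaded_checkpoint
    return loaded_model


with tempfile.TemporaryDirectory() as tmp:
    save_parts(ToyWrapper(weight=0.5), tmp)
    restored = load_parts(tmp)
    print(restored.predict(3))  # [0.5, 0.5, 0.5]
```

Keeping the network in its own file is what makes the re-attach step possible, since the wrapper's pickle deliberately leaves it out.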
I want to use joblib to dump and load the fitted model, but when I used the loaded model, it reported `AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'`. Suppose that I have trained a `TSMixerModel` as follows,
and I want to save the fitted model using joblib (I know that the checkpoint file can be found in the `./darts_logs` directory), but when I load the model to predict, it reports the error
`AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'`
How should I properly save and load the fitted model?