Hi @ZhangAllen98,
The checkpoints for Darts deep learning models are split into two files: the "Darts" model, which handles all the time-series-specific logic (`.pt`), and the weights themselves (`.ckpt`). Based on the error message you are getting, I am guessing that the weights (the "torch model") are not being loaded/exported properly.

Can you check whether you are able to call `predict()` after loading the model using `load_from_checkpoint()`?
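For example, something like this (a sketch; `"my_model"` and `n=12` are placeholders for your own setup):

```python
from darts.models import TSMixerModel

# load the best checkpoint written under ./darts_logs/<model_name>/checkpoints
# (requires the model to have been created with save_checkpoints=True)
model = TSMixerModel.load_from_checkpoint("my_model", best=True)

# if this predicts fine, the checkpoint itself is intact
forecast = model.predict(n=12)
```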
I am not sure how to fix the way joblib pickles the model; maybe other users have some experience with that.
Internally we currently store two things when calling `TorchForecastingModel.save()` (or with checkpoints directly):

- The `TorchForecastingModel` object, which is the Darts framework wrapper around the neural network.
- The neural network checkpoint containing the model architecture parameters.

I assume joblib internally tries to pickle the object, which will call `TorchForecastingModel.__getstate__()` (saving) and `TorchForecastingModel.__setstate__()` (loading). We configured it so that when pickling the model, the actual `TorchForecastingModel.model` attribute (the neural network with all its parameters) is not stored, since otherwise we would have to save it twice.
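A minimal sketch of that pickling pattern (simplified, not the actual Darts source):

```python
class WrapperLike:
    """Illustrates a wrapper that drops its heavy sub-model when pickled."""

    def __init__(self, model):
        self.model = model  # the neural network

    def __getstate__(self):
        state = self.__dict__.copy()
        state["model"] = None  # skip the network; the checkpoint file already holds it
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)  # after unpickling, .model is None until restored
```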
When loading with `TorchForecastingModel.load*()`, it will retrieve the checkpoint file and add it to the wrapper.

So I guess for this to work, you would have to store the two files (wrapper and model) with joblib, then load the wrapper and model separately, and in the end store the model under the wrapper. Something like this (I haven't tested it, though):
```python
loaded_model = joblib.load('./wrapper_save')
loaded_checkpoint = joblib.load('./model_checkpoint_save')
loaded_model.model = loaded_checkpoint
```
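For completeness, the save side would mirror this (again untested; the file names are placeholders):

```python
import joblib

# dump the Darts wrapper; its .model attribute is dropped during pickling
joblib.dump(model, './wrapper_save')
# dump the underlying neural network separately
joblib.dump(model.model, './model_checkpoint_save')
```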
Yes, in this case I first dump the model and then dump `model.model`; at the inference stage, I load these two files and it works.
Can you try the code snippet and close this issue if this approach works?
I want to use `joblib` to dump and load the fitted model, but when I used the loaded model, it reported `AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'`.

Suppose that I have trained a TSMixerModel and want to save the fitted model using `joblib` (I know that the checkpoint file can be found in the `./darts_logs` directory), but when I load the model to predict, it reports the error `AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'`.

How should I properly save and load the fitted model? A minimal reproduction is sketched below.
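A minimal sketch of the setup (the toy series and hyperparameters are placeholders, not the original training code):

```python
import joblib
import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.models import TSMixerModel

# toy series standing in for the real training data
series = TimeSeries.from_series(
    pd.Series(
        np.sin(np.linspace(0, 20, 200)),
        index=pd.date_range("2020-01-01", periods=200, freq="D"),
    )
)

model = TSMixerModel(input_chunk_length=24, output_chunk_length=12, n_epochs=5)
model.fit(series)

joblib.dump(model, "tsmixer.joblib")

loaded = joblib.load("tsmixer.joblib")
# fails because the unpickled wrapper's .model attribute is None:
# AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'
loaded.predict(n=12)
```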