unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

How to use dump or load of joblib to save or reuse fitted model #2531

Closed. ZhangAllen98 closed this issue 1 month ago.

ZhangAllen98 commented 2 months ago

I want to use joblib to dump and load a fitted model, but when I use the loaded model it raises AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'.

Suppose that I have trained a TSMixerModel as follows,

model = TSMixerModel(
    **create_params(
        input_chunk_length,
        output_chunk_length,
        full_training=full_training,
    ),
    hidden_size=64,
    ff_size=64,
    num_blocks=2,
    activation="GELU",
    dropout=0.2,
    use_static_covariates=True,
    model_name='tsm',
)

model.fit(
    series=y_train,
    future_covariates=x_train,
    val_series=y_val,
    val_future_covariates=x_val,
)

fitted_model = model.load_from_checkpoint(
    model_name=model.model_name, best=True
)

and I want to save the fitted model using joblib (I know that the checkpoint files can be found in the ./darts_logs directory):

joblib.dump(fitted_model, './test_tsm_model_save')

but when I load the model and call predict, it raises AttributeError: 'NoneType' object has no attribute 'set_predict_parameters':

loaded_model = joblib.load('./test_tsm_model_save')
y_pred = loaded_model.predict(
    n=output_chunk_length,
    series=y_input,
    past_covariates=None,
    future_covariates=x_input,
    num_samples=num_samples,
    predict_likelihood_parameters=False,
)

How should I properly save and load the fitted model?

madtoinou commented 2 months ago

Hi @ZhangAllen98,

The checkpoints for Darts deep learning models are split into two files: the "Darts" model, which handles all the time-series-specific logic (.pt), and the weights themselves (.ckpt). Based on the error message you are getting, I am guessing that the weights (the "torch model") are not being loaded/exported properly.
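For reference, Darts' own persistence writes both parts for you; a minimal sketch (the file name tsm_model.pt is just an example):

from darts.models import TSMixerModel

# save() creates two files: "tsm_model.pt" (the Darts wrapper) and
# "tsm_model.pt.ckpt" (the network weights).
model.save('tsm_model.pt')

# load() restores the wrapper and re-attaches the weights.
loaded = TSMixerModel.load('tsm_model.pt')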

Can you check if you are able to call predict() after loading the model using load_from_checkpoint()?
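Something like this, reusing the names from your snippet (untested):

from darts.models import TSMixerModel

# Reload the best checkpoint that fit() wrote under ./darts_logs/tsm/
fitted_model = TSMixerModel.load_from_checkpoint(model_name='tsm', best=True)

# If this works, the checkpoint itself is fine and the problem lies in
# how joblib pickles the wrapper.
y_pred = fitted_model.predict(
    n=output_chunk_length,
    series=y_input,
    future_covariates=x_input,
)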

I am not sure how to fix the way joblib pickles the model; maybe other users have some experience with that.

dennisbader commented 2 months ago

Internally we currently store two things when calling TorchForecastingModel.save() (or with checkpoints directly):

    1. The TorchForecastingModel object, which is the Darts framework wrapper around the neural network.
    2. The neural network checkpoint containing the model architecture parameters.

I assume joblib internally tries to pickle the object, which will call TorchForecastingModel.__getstate__() (saving) and TorchForecastingModel.__setstate__() (loading).

We configured it so that when pickling the model, the actual TorchForecastingModel.model attribute (the neural network with all its parameters) is not stored, since otherwise we would have to save it twice.
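You can see this directly after unpickling (assuming the dump from the original post):

import joblib

wrapper = joblib.load('./test_tsm_model_save')
# The network was dropped during pickling, so calling predict() fails with
# "'NoneType' object has no attribute 'set_predict_parameters'".
print(wrapper.model)  # None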

When loading with TorchForecastingModel.load*(), it will retrieve the checkpoint file and add it to the wrapper.

So I guess for this to work, you would have to store the two files (wrapper and model) with joblib, then load the wrapper and the model separately, and in the end attach the model to the wrapper. Something like this (haven't tested this though):

loaded_model = joblib.load('./wrapper_save')
loaded_checkpoint = joblib.load('./model_checkpoint_save')
loaded_model.model = loaded_checkpoint
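Spelled out as a full round trip (the file names are placeholders, and this is untested):

import joblib

# Save: dump the wrapper and the underlying network separately, since
# pickling the wrapper alone drops its .model attribute.
joblib.dump(fitted_model, './wrapper_save')
joblib.dump(fitted_model.model, './model_checkpoint_save')

# Load: restore both objects and re-attach the network to the wrapper.
loaded_model = joblib.load('./wrapper_save')
loaded_model.model = joblib.load('./model_checkpoint_save')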
ZhangAllen98 commented 2 months ago


Yes, in this case I first dump the model and then dump model.model; at inference time I load these two files and re-attach them, and it works.

madtoinou commented 2 months ago

Can you try the code snippet and close this issue if this approach works?