unit8co / darts

A Python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

How to use dump or load of joblib to save or reuse fitted model #2531

Open ZhangAllen98 opened 4 days ago

ZhangAllen98 commented 4 days ago

I want to use joblib to dump and load a fitted model, but when I use the loaded model it raises AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'.

Suppose that I have trained a TSMixerModel as follows,

model = TSMixerModel(
    **create_params(
        input_chunk_length,
        output_chunk_length,
        full_training=full_training,
    ),
    hidden_size=64,
    ff_size=64,
    num_blocks=2,
    activation="GELU",
    dropout=0.2,
    use_static_covariates=True,
    model_name='tsm',
)

model.fit(
    series=y_train,
    future_covariates=x_train,
    val_series=y_val,
    val_future_covariates=x_val,
)

fitted_model = model.load_from_checkpoint(
    model_name=model.model_name, best=True
)

and I want to save the fitted model using joblib (I know that the checkpoint file can be found in the ./darts_logs directory):

joblib.dump(fitted_model,'./test_tsm_model_save')

but when I load the model to predict, it raises AttributeError: 'NoneType' object has no attribute 'set_predict_parameters':

loaded_model = joblib.load('./test_tsm_model_save')
y_pred = loaded_model.predict(
    n=output_chunk_length,
    series=y_input,
    past_covariates=None,
    future_covariates=x_input,
    num_samples=num_samples,
    predict_likelihood_parameters=False,
)

What should I do to properly save and load the fitted model?

madtoinou commented 2 days ago

Hi @ZhangAllen98,

The checkpoints for Darts deep learning models are split into two files: the "Darts" model (which handles all the timeseries-specific logic, .pt) and the weights themselves (.ckpt). Based on the error message you are getting, I am guessing that the weights (the "torch model") are not being loaded/exported properly.

Can you check if you are able to call predict() after loading the model using load_from_checkpoint()?

I am not sure how to fix the way joblib pickles the model; maybe other users have some experience with that.

dennisbader commented 2 days ago

Internally we currently store two things when calling TorchForecastingModel.save() (or with checkpoints directly): the TorchForecastingModel wrapper itself (.pt) and the checkpoint containing the network weights (.ckpt).

I assume joblib internally tries to pickle the object, which will call TorchForecastingModel.__getstate__() (saving) and TorchForecastingModel.__setstate__() (loading).

We configured it so that when pickling the model, the actual TorchForecastingModel.model attribute (the neural network with all its parameters) is not stored, since otherwise we would have to save it twice.
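To illustrate why the loaded object ends up with a None network, here is a minimal, self-contained sketch of that pickling pattern. The Wrapper and Net classes are hypothetical stand-ins (not Darts classes), and plain pickle is used in place of joblib; the mechanism, dropping a heavy attribute in __getstate__ and reattaching it after loading, is the same:

```python
import pickle

class Net:
    """Toy stand-in for the neural network."""
    def forward(self, n):
        return [0.0] * n

class Wrapper:
    """Toy stand-in for a model wrapper that excludes its heavy
    'model' attribute when pickled."""
    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # deliberately drop the network from the pickled state
        state = self.__dict__.copy()
        state['model'] = None
        return state

    def predict(self, n):
        # raises AttributeError if self.model is still None
        return self.model.forward(n)

wrapper = Wrapper(Net())
blob_wrapper = pickle.dumps(wrapper)    # wrapper WITHOUT the net
blob_net = pickle.dumps(wrapper.model)  # net saved separately

loaded = pickle.loads(blob_wrapper)
# at this point loaded.model is None -> predict() would fail,
# which mirrors the AttributeError reported in this issue
loaded.model = pickle.loads(blob_net)   # reattach the network
print(loaded.predict(3))                # [0.0, 0.0, 0.0]
```

Pickling only the wrapper and forgetting the second step is what produces the 'NoneType' error above.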

When loading with TorchForecastingModel.load*(), it will retrieve the checkpoint file and attach it to the wrapper.

So I guess for this to work, you would have to store the two objects (wrapper and model) with joblib, then load the wrapper and model separately, and finally assign the model back to the wrapper. Something like this (I haven't tested this, though):

# load the wrapper and the network separately, then reattach them
loaded_model = joblib.load('./wrapper_save')
loaded_checkpoint = joblib.load('./model_checkpoint_save')
loaded_model.model = loaded_checkpoint