unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Adding custom parameters in the Darts TFT model #1832

Closed: sargamg99 closed this issue 9 months ago

sargamg99 commented 1 year ago

Hello! I have created two TFT models and can access their parameters. I then performed some operations on their parameters, and now I want to load these new custom parameters into a TFT model. I tried using "modelk.model.load_state_dict(newparams)", but it raises "AttributeError: 'NoneType' object has no attribute 'load_state_dict'". How do I resolve this issue?

madtoinou commented 1 year ago

Hi @sargamg99,

Sharing a code snippet (for the model saving/loading) would make it easier to help you. However, I think your problem comes from the fact that the underlying model is only instantiated during the first fit() call, so that some parameters can be detected and adapted based on the data; until then, the model attribute is None. You can, however, use model._setup_for_train(...) to instantiate the model prior to loading your parameters.

Let me know if it solves the issue.
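
To illustrate why the attribute is None before training, here is a toy sketch of the lazy-instantiation pattern described above. `LazyForecaster` and its methods are hypothetical stand-ins, not Darts classes; they only mimic a wrapper that builds its inner network on the first fit() call.

```python
class LazyForecaster:
    """Hypothetical wrapper: the inner model is built lazily, on first fit()."""

    def __init__(self, input_chunk_length, output_chunk_length):
        self.input_chunk_length = input_chunk_length
        self.output_chunk_length = output_chunk_length
        self.model = None  # nothing to load parameters into yet

    def _init_model(self, n_features):
        # Stand-in for the real network: its size depends on the data.
        return {"weights": [0.0] * (self.input_chunk_length * n_features)}

    def fit(self, series):
        # series: list of rows, each row a list of feature values
        if self.model is None:
            self.model = self._init_model(n_features=len(series[0]))
        return self


forecaster = LazyForecaster(input_chunk_length=4, output_chunk_length=2)
print(forecaster.model)  # None: forecaster.model.<anything> fails at this point

forecaster.fit([[1.0], [2.0], [3.0], [4.0], [5.0]])
print(len(forecaster.model["weights"]))  # 4
```

In this toy version, any attribute access on `forecaster.model` before `fit()` raises the same kind of AttributeError on 'NoneType' as reported in the issue.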

sargamg99 commented 1 year ago

No. This isn't working. It is showing the same error.

# modelm and my_model are two TFT models trained on some data
sdA = modelm.model.state_dict()
sdB = my_model.model.state_dict()

for key in sdA:
    sdB[key] = (sdB[key] + sdA[key]) / 2

modelk = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon
)
modelk.model._setup_for_train()  # I might not have called the function in the right way.
modelk.model.load_state_dict(sdB)

ERROR : AttributeError: 'NoneType' object has no attribute '_setup_for_train'

madtoinou commented 1 year ago

Some parameters are indeed missing. I wrote a minimal example with the NBEATSModel but it should be exactly the same for the TFTModel:

from darts.models import NBEATSModel

# modelA and modelB were previously trained on a simple univariate series, without covariates
sdA = modelA.model.state_dict()
sdB = modelB.model.state_dict()

new_state_dict = {}
for key in sdA:
    new_state_dict[key] = (sdB[key] + sdA[key]) / 2

# clean approach, guarantees that the model can be trained with the series ts
modelk = NBEATSModel(
    input_chunk_length=4,
    output_chunk_length=2
)
# the inputs should be identical to those of modelA.fit() and modelB.fit()
train_ds = modelk._build_train_dataset(target=ts,
                                       past_covariates=None,
                                       future_covariates=None,
                                       max_samples_per_ts=None)
trainer, new_model, train_loader, val_loader = modelk._setup_for_train(train_ds)
modelk.model = new_model
modelk.model.load_state_dict(new_state_dict)

# a simpler approach but without any sanity check
modelk = NBEATSModel(
    input_chunk_length=4,
    output_chunk_length=2
)
modelk.train_sample = modelA.train_sample
modelk.output_dim = modelA.output_dim
modelk.model = modelk._init_model()
modelk.model.load_state_dict(new_state_dict)

Make sure to close this issue if this snippet solves your problem.
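
The averaging step in the snippet above can also be wrapped in a small helper that checks the two state dicts share the same keys before combining them. This is a minimal sketch, not part of Darts: `average_state_dicts` is a hypothetical name, and the toy dicts of floats stand in for the output of modelA.model.state_dict() and modelB.model.state_dict(). The function only assumes the values support `+` and `/`, which holds for floats as well as tensors.

```python
def average_state_dicts(sd_a, sd_b):
    """Return a new dict with the element-wise mean of two state dicts.

    Works with any values supporting + and / (e.g. tensors or floats).
    """
    if sd_a.keys() != sd_b.keys():
        differing = set(sd_a.keys()) ^ set(sd_b.keys())
        raise KeyError(f"state dicts differ on keys: {sorted(differing)}")
    return {key: (sd_a[key] + sd_b[key]) / 2 for key in sd_a}


# Toy stand-ins for modelA.model.state_dict() and modelB.model.state_dict()
sd_a = {"layer.weight": 1.0, "layer.bias": 0.5}
sd_b = {"layer.weight": 3.0, "layer.bias": 1.5}
new_state_dict = average_state_dicts(sd_a, sd_b)
print(new_state_dict)  # {'layer.weight': 2.0, 'layer.bias': 1.0}
```

Because the helper builds a new dict, neither input is modified. Entries that are not floating-point values (integer counters, for example) may need separate handling, which this sketch does not address.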