unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Confused about saving checkpoint, loading and predicting #2506

Closed by valentin-fngr 3 months ago

valentin-fngr commented 3 months ago

Hi,

I am confused about saving and loading checkpoints. I have set up training with the TFT model as follows:

import torch.nn as nn
from darts.models import TFTModel

my_model_AMD = TFTModel(
    input_chunk_length=input_chunk_length,
    output_chunk_length=forecast_horizon,
    hidden_size=512,
    lstm_layers=1,
    num_attention_heads=8,
    dropout=0.12,
    batch_size=128,
    n_epochs=300,   
    add_relative_index=False,
    add_encoders=None,
    likelihood=None,  # with loss_fn set, the model is deterministic (QuantileRegression is only the default when both are None)
    loss_fn=nn.MSELoss(),
    random_state=42,
    optimizer_kwargs={
        "lr": 1e-2
    },
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "devices": [0], 
        # "callbacks": EarlyStopping
    },
    lr_scheduler_cls=lr_scheduler_cls,
    lr_scheduler_kwargs=lr_scheduler_kwargs,
    work_dir=work_dir,
    force_reset=True,
    log_tensorboard=True,
    model_name="TFT_nasdaq_2_2",
    save_checkpoints=True, 
    # variable_selection_module = _VariableSelectionNetwork
)

my_model_AMD.fit(
    series=train_target, 
    future_covariates=train_future_cov, 
    past_covariates=train_past_cov,
    val_series=val_target, 
    val_future_covariates=val_future_cov,
    val_past_covariates=val_past_cov,
    verbose=False, 
)

This generates logs as well as saved checkpoints in the darts_logs/TFT_nasdaq_2_2/checkpoints/ folder.
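With save_checkpoints=True, the checkpoint folder holds files named like best-epoch=19-val_loss=0.01.ckpt (the file passed to load_from_checkpoint() in this issue). As a minimal sketch, assuming only that naming pattern, the hypothetical helper find_best_checkpoint() below (not part of Darts) lists such files and returns the one whose name encodes the lowest val_loss. Note that best=True in load_from_checkpoint() is meant to select the best checkpoint for you, so this is only useful for inspecting the folder.

```python
import os
import re
from typing import Optional

# Pattern assumed from the checkpoint name seen in this issue,
# e.g. "best-epoch=19-val_loss=0.01.ckpt"
_BEST_CKPT = re.compile(r"^best-epoch=(\d+)-val_loss=(\d+(?:\.\d+)?)\.ckpt$")


def find_best_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the path of the 'best-*' checkpoint with the lowest val_loss, or None."""
    candidates = []
    for name in os.listdir(checkpoint_dir):
        match = _BEST_CKPT.match(name)
        if match:
            epoch, val_loss = int(match.group(1)), float(match.group(2))
            candidates.append((val_loss, epoch, name))
    if not candidates:
        return None
    # Lowest val_loss wins; since names round val_loss, ties go to the earlier epoch
    _, _, best_name = min(candidates)
    return os.path.join(checkpoint_dir, best_name)
```

Files that do not match the pattern (for example "last-..." checkpoints or other files) are simply ignored.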

I then stopped the training and loaded the weights as follows:

my_model_AMD.load_from_checkpoint(model_name="TFT_nasdaq_2_2", best=True, file_name="/workspace/tft_experiment/custom_dart/examples/darts_logs/TFT_nasdaq_2_2/checkpoints/best-epoch=19-val_loss=0.01.ckpt")

This call runs without errors.

However, when I then run the backtest as follows:

backtest_series = my_model_AMD.historical_forecasts(
    target_transformed,
    future_covariates=future_cov_transformed,
    past_covariates=past_cov_transformed,
    start=training_cutoff,
    num_samples=1,
    forecast_horizon=forecast_horizon,
    stride=forecast_horizon,
    last_points_only=False,
    retrain=False,
    verbose=True,
)

I get the following error:

AttributeError: 'NoneType' object has no attribute 'set_predict_parameters'

Maybe I am misunderstanding how to load weights in Darts.

Thanks.

dennisbader commented 3 months ago

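load_from_checkpoint() does not load the weights into the model you call it on: it returns a new model object loaded from the checkpoint, and my_model_AMD itself is left unchanged. You have to assign the result to a variable and use that variable for prediction: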

model_loaded = TFTModel.load_from_checkpoint(model_name="TFT_nasdaq_2_2", best=True)
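The difference between the two calls can be reproduced without Darts. In the simplified stand-in below, Forecaster and _Module are hypothetical classes, not Darts code; they only loosely mirror a model that delegates prediction to an internal module which stays None until weights are loaded. Here load_from_checkpoint() is a static method that builds and returns a new instance, so calling it on an existing, untrained instance and discarding the return value leaves that instance unchanged, and predicting with it raises the same kind of AttributeError as in the issue.

```python
class _Module:
    """Hypothetical stand-in for the trained network held by a model."""

    def set_predict_parameters(self, n: int) -> None:
        self.n = n

    def predict(self) -> list:
        return [0.0] * self.n


class Forecaster:
    """Hypothetical model wrapper; `model` stays None until weights are loaded."""

    def __init__(self) -> None:
        self.model = None

    @staticmethod
    def load_from_checkpoint(model_name: str) -> "Forecaster":
        # Builds and RETURNS a new instance; nothing is modified in place.
        # model_name is unused in this stand-in.
        loaded = Forecaster()
        loaded.model = _Module()
        return loaded

    def historical_forecasts(self, n: int) -> list:
        # Raises AttributeError if self.model is still None
        self.model.set_predict_parameters(n)
        return self.model.predict()


untrained = Forecaster()

# Pitfall: the returned, loaded instance is discarded
untrained.load_from_checkpoint(model_name="TFT_nasdaq_2_2")
try:
    untrained.historical_forecasts(n=3)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'set_predict_parameters'

# Fix: keep a reference to the returned instance
model_loaded = Forecaster.load_from_checkpoint(model_name="TFT_nasdaq_2_2")
print(model_loaded.historical_forecasts(n=3))  # [0.0, 0.0, 0.0]
```

In Darts terms, the fix is to keep the return value, e.g. model_loaded = TFTModel.load_from_checkpoint(model_name="TFT_nasdaq_2_2", work_dir=work_dir, best=True), and then call model_loaded.historical_forecasts(...) with the same arguments as before. Passing work_dir, assumed here to be the same value used at training time, matters when the model was trained with a custom work_dir, as it was in this issue.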