unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

[QUESTION] Is there ability to continue training of model from checkpoint? #1400

Closed: mzillag closed this issue 1 year ago

mzillag commented 1 year ago
  1. Is it possible to continue training a model from a checkpoint?
  2. I'm training a TFT model on a dataset of multiple time series. As past covariates I'm using 3 covariates that change over time (known at any time) and 11 that do not change over time. Is this the proper way to use this data, or is there a better way?
dennisbader commented 1 year ago

Hey @mzillag and sorry for the late reply.

  1. Yes, you can use the load_from_checkpoint() method, i.e. TFTModel.load_from_checkpoint(); see the docs here.
  2. For multiple time series (i.e. a list of univariate or multivariate target series), you must pass the past covariates as a list of series with the same number of elements as your list of target series. Each element is a multivariate time series containing the 3 time-varying covariate components. The 11 covariates that do not change over time (static) can be embedded as static covariates in each of your target series. See for example the TimeSeries methods with_static_covariates() or from_group_dataframe().
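A minimal sketch of the preparation step, assuming the raw data sits in a long-format pandas DataFrame with one row per series and time step. The column names (series_id, time, target, temp, region) and the helper split_covariates are hypothetical. The sketch detects which covariate columns are constant within every series (static) and which vary over time, then groups them per series so that there is exactly one covariate entry per target series. In Darts, the static columns could then be passed as static_cols to TimeSeries.from_group_dataframe(), and the time-varying ones would become the components of each past-covariates series.

```python
from typing import List, Tuple

import pandas as pd

# Hypothetical long-format data: one row per (series_id, time step).
df = pd.DataFrame({
    "series_id": ["A"] * 3 + ["B"] * 3,
    "time": list(range(3)) * 2,
    "target": [1.0, 2.0, 3.0, 10.0, 11.0, 12.0],
    "temp": [20.1, 21.3, 19.8, 15.0, 15.5, 16.2],  # changes over time
    "region": [1, 1, 1, 2, 2, 2],                  # constant within each series
})


def split_covariates(
    df: pd.DataFrame, group_col: str, time_col: str, target_col: str
) -> Tuple[List[str], List[str]]:
    """Return (static_cols, dynamic_cols): covariates constant within every group vs. the rest."""
    covariate_cols = [c for c in df.columns if c not in (group_col, time_col, target_col)]
    # number of distinct values of each covariate inside each group
    n_unique = df.groupby(group_col)[covariate_cols].nunique()
    static_cols = [c for c in covariate_cols if (n_unique[c] == 1).all()]
    dynamic_cols = [c for c in covariate_cols if c not in static_cols]
    return static_cols, dynamic_cols


static_cols, dynamic_cols = split_covariates(df, "series_id", "time", "target")

# One entry per target series: time-varying covariates (ordered by time) and static values.
per_series = {
    sid: {
        "past_covariates": g.sort_values("time")[dynamic_cols].reset_index(drop=True),
        "static_covariates": g[static_cols].iloc[0],
    }
    for sid, g in df.groupby("series_id")
}
```

Because every series gets its own entry, the resulting list of covariates automatically has the same length as the list of target series.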
mzillag commented 1 year ago

Thanks for the answer. So, can I do something like this (I need to perform custom validation for my dataset):

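import torchmetrics
from darts.models import TFTModel

# x_train / y_train are my own lists of target series and past covariate series (not shown)
darts_model = TFTModel(input_chunk_length=12,
                       output_chunk_length=4,
                       add_relative_index=True,
                       work_dir='darts_models',
                       save_checkpoints=False,
                       log_tensorboard=False,
                       torch_metrics=torchmetrics.MetricCollection([torchmetrics.MeanAbsolutePercentageError(),
                                                                        torchmetrics.MeanAbsoluteError()]))

# fit one epoch at a time so that custom validation can run between epochs
for i in range(30):
    print(f'\n\n ------------------------------ Epoch {i + 1} ------------------------------ \n\n')
    darts_model.fit(series=x_train,
                    past_covariates=y_train,
                    epochs=1)
    # custom validation would go here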
dennisbader commented 1 year ago

Loading a model from a checkpoint is meant for two cases (in Darts):

  • if there was an error during training, we can resume training from a checkpoint saved before the crash (up to the n epochs specified when the model was first trained).
  • loading the model from a checkpoint for inference (not for training it again).

Calling fit() n times with epochs=1 will not give you the same results as a single fit() call with epochs=n.

Have you considered using a validation set when calling fit()? Here is the documentation. Is your custom validation something you can address with a PyTorch Lightning callback, which you can pass at model creation via pl_trainer_kwargs? See the docs here.

mzillag commented 1 year ago

No, unfortunately I can't address it with a PyTorch Lightning callback. So there's no way to train the model the way I described above?

dennisbader commented 1 year ago

I was interested in this question, so I gave it a go. I found a solution:

With this approach, we can use the Darts API (e.g. predict() and the metrics in darts.metrics) for custom evaluation during model training:

from pytorch_lightning.callbacks import Callback

from darts.metrics import rmse
from darts.models import TFTModel
from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel
from darts.utils import timeseries_generation as tg

# generate time series
ts = tg.sine_timeseries(length=55)
ts_train, ts_val = ts[:50], ts[50:]
covs = tg.linear_timeseries(length=60)

model_name = "tft_test"
n_epochs = 10
seed = 123  # for reproducible results

def evaluate_model(model):
    # do custom validation here, i.e.:
    # evaluate the model by taking the RMSE between the val series and the 0.5 quantile from stochastic predictions
    return rmse(
        actual_series=ts_val,
        pred_series=model.predict(n=5, series=ts_train, future_covariates=covs, num_samples=100, verbose=False)
    )

def load_and_evaluate(epoch):
    # load checkpoint from last epoch and evaluate
    model = TorchForecastingModel.load_from_checkpoint(model_name, best=False)
    logs = {
        "rmse": evaluate_model(model),
        "epoch": epoch
    }
    print(logs)
    return logs

# I couldn't find a callback hook that runs after the checkpoint is saved, so we do it in two steps:
# 1) on_train_epoch_start() loads the checkpoint saved at the end of the previous epoch. On its own,
#    this never evaluates the final epoch.
# 2) on_train_end() loads the state after the final epoch, so that it is evaluated as well.
class EvalCallback(Callback):
    def __init__(self):
        self.logs = []

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        # skip the first epoch as no checkpoint has been saved
        if trainer.current_epoch != 0:
            self.logs.append(load_and_evaluate(trainer.current_epoch - 1))

    def on_train_end(self, trainer, pl_module) -> None:
        # do the same at the end of training to also evaluate the final epoch
        self.logs.append(load_and_evaluate(trainer.current_epoch - 1))

# create a model including the epoch-wise evaluation logger
eval_logger = EvalCallback()
model1 = TFTModel(
    input_chunk_length=10,
    output_chunk_length=10,
    random_state=seed,
    save_checkpoints=True,  # enable checkpointing
    model_name=model_name,  # give model name for checkpointing
    force_reset=True,  # delete previous save files
    pl_trainer_kwargs={
        "enable_progress_bar": False,  # disable verbosity for demonstration
        "enable_model_summary": False,
        "callbacks": [eval_logger]
    }
)

model1.fit(ts_train, future_covariates=covs, epochs=n_epochs, verbose=False)

rmse_model1 = eval_logger.logs[-1]["rmse"]
# rmse_model1 (loaded from a checkpoint) is equal to evaluating the fitted model for the first time
assert rmse_model1 == evaluate_model(model1)
# calling evaluate a second time gives different results due to probabilistic sampling
assert rmse_model1 != evaluate_model(model1)

# compare results to fitting a model without the evaluation callback and without checkpointing
model2 = TFTModel(
    input_chunk_length=10,
    output_chunk_length=10,
    random_state=seed,
    pl_trainer_kwargs={
        "enable_progress_bar": False,
        "enable_model_summary": False,
    }
)

model2.fit(ts_train, future_covariates=covs, epochs=n_epochs)

rmse_model2 = evaluate_model(model2)
# calling evaluate a second time gives different results due to probabilistic sampling
assert rmse_model2 != evaluate_model(model2)

# model results will differ only slightly due to probabilistic sampling of predict()
print(rmse_model1, rmse_model2)
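Once EvalCallback has collected one {"rmse": ..., "epoch": ...} entry per epoch, the logs can drive further decisions. A minimal sketch, assuming lower RMSE is better; the helpers best_epoch and no_improvement are hypothetical and not part of Darts or PyTorch Lightning:

```python
from typing import Any, Dict, List

# Each entry mirrors what load_and_evaluate() returns: {"rmse": ..., "epoch": ...}
Log = Dict[str, Any]


def best_epoch(logs: List[Log], metric: str = "rmse") -> Log:
    """Return the log entry with the lowest value of `metric` (lower is better for RMSE)."""
    if not logs:
        raise ValueError("no evaluation logs were recorded")
    return min(logs, key=lambda entry: entry[metric])


def no_improvement(logs: List[Log], patience: int, metric: str = "rmse") -> bool:
    """True if none of the last `patience` entries beats the best entry recorded before them."""
    if patience < 1 or len(logs) <= patience:
        return False
    best_before = min(entry[metric] for entry in logs[:-patience])
    recent_best = min(entry[metric] for entry in logs[-patience:])
    return recent_best >= best_before
```

For example, after model1.fit(...), best_epoch(eval_logger.logs)["epoch"] gives the epoch with the lowest RMSE. Calling no_improvement(self.logs, patience=3) inside EvalCallback.on_train_epoch_start() and setting trainer.should_stop = True when it returns True is one possible way to stop early on the custom metric; this is an untested assumption, not something verified in this thread.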