mzillag closed this issue 1 year ago.
Hey @mzillag, and sorry for the late reply.
- Yes, you can use the load_from_checkpoint() method, i.e. TFTModel.load_from_checkpoint(); see the docs here.
- For multiple time series (i.e. a list of univariate or multivariate target series), you must pass a list of past covariate series with the same number of elements as your list of target series. Each element is a multivariate time series with the 3 covariate components. The 11 covariates that do not change over time (static) can be embedded as static covariates in each of your target series. See, for example, the TimeSeries methods with_static_covariates() or from_group_dataframe().
Thanks for the answer. So, can I do something like this (I need to perform custom validation for my dataset):
import torchmetrics
from darts.models import TFTModel

darts_model = TFTModel(input_chunk_length=12,
                       output_chunk_length=4,
                       add_relative_index=True,
                       work_dir='darts_models',
                       save_checkpoints=False,
                       log_tensorboard=False,
                       torch_metrics=torchmetrics.MetricCollection([torchmetrics.MeanAbsolutePercentageError(),
                                                                    torchmetrics.MeanAbsoluteError()]))
# x_train (target series) and y_train (past covariates) are defined elsewhere
for i in range(30):
    print(f'\n\n ------------------------------ Epoch {i+1} ------------------------------ \n\n')
    darts_model.fit(series=x_train,
                    past_covariates=y_train, epochs=1,
                    )
    # custom validation of darts_model after each epoch would go here
In Darts, loading a model from a checkpoint is meant for two cases:
- if there was an error during training, to resume training from a checkpoint saved before the crash (up until the n epochs specified when the model was first trained);
- to load the model from a checkpoint for inference (not for training it again).
Training iteratively n times for one epoch each will not give you the same results as fitting for n epochs in a single fit() call.
Have you considered using a validation set when calling fit()? Here is the documentation. Is your custom validation something you can address with a PyTorch Lightning callback that you pass at model creation with pl_trainer_kwargs? See the docs here.
No, unfortunately I can't address it with a PyTorch Lightning callback. So there's no way to train the model the way I described before?
I was interested in this question, so I gave it a go and found a solution: a PyTorch Lightning callback that loads the latest checkpoint at each epoch and evaluates it.
With this, we can leverage the Darts API during model training as well:
from pytorch_lightning.callbacks import Callback

from darts.metrics import rmse
from darts.models import TFTModel
from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel
from darts.utils import timeseries_generation as tg

# generate time series
ts = tg.sine_timeseries(length=55)
ts_train, ts_val = ts[:50], ts[50:]
covs = tg.linear_timeseries(length=60)

model_name = "tft_test"
n_epochs = 10
seed = 123  # for reproducible results


def evaluate_model(model):
    # do custom validation here, e.g.:
    # evaluate model by taking RMSE between val series and 0.5 quantile from stochastic predictions
    return rmse(
        actual_series=ts_val,
        pred_series=model.predict(n=5, series=ts_train, future_covariates=covs, num_samples=100, verbose=False)
    )


def load_and_evaluate(epoch):
    # load checkpoint from last epoch and evaluate
    model = TorchForecastingModel.load_from_checkpoint(model_name, best=False)
    logs = {
        "rmse": evaluate_model(model),
        "epoch": epoch
    }
    print(logs)
    return logs


# I couldn't find a callback hook which acts after saving the checkpoint, so we need to do it in two steps:
# 1) on_train_epoch_start() will load the checkpoint from the previous epoch. With this we can't load
#    the final epoch
# 2) with on_train_end() we can also load the state of the final epoch
class EvalCallback(Callback):
    def __init__(self):
        self.logs = []

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        # skip the first epoch as no checkpoint has been saved yet
        if trainer.current_epoch != 0:
            self.logs.append(load_and_evaluate(trainer.current_epoch - 1))

    def on_train_end(self, trainer, pl_module) -> None:
        # do the same at the end of training to also evaluate the final epoch
        self.logs.append(load_and_evaluate(trainer.current_epoch - 1))


# create a model including the epoch-wise evaluation logger
eval_logger = EvalCallback()
model1 = TFTModel(
    input_chunk_length=10,
    output_chunk_length=10,
    random_state=seed,
    save_checkpoints=True,  # enable checkpointing
    model_name=model_name,  # give model name for checkpointing
    force_reset=True,  # delete previous save files
    pl_trainer_kwargs={
        "enable_progress_bar": False,  # disable verbosity for demonstration
        "enable_model_summary": False,
        "callbacks": [eval_logger]
    }
)
model1.fit(ts_train, future_covariates=covs, epochs=n_epochs, verbose=False)
rmse_model1 = eval_logger.logs[-1]["rmse"]

# rmse_model1 (loaded from a checkpoint) is equal to evaluating the fitted model for the first time
assert rmse_model1 == evaluate_model(model1)
# calling evaluate a second time gives different results due to probabilistic sampling
assert rmse_model1 != evaluate_model(model1)

# compare results to fitting a model without checkpointing and without the evaluation callback
model2 = TFTModel(
    input_chunk_length=10,
    output_chunk_length=10,
    random_state=seed,
    pl_trainer_kwargs={
        "enable_progress_bar": False,
        "enable_model_summary": False,
    }
)
model2.fit(ts_train, future_covariates=covs, epochs=n_epochs)
rmse_model2 = evaluate_model(model2)
# calling evaluate a second time gives different results due to probabilistic sampling
assert rmse_model2 != evaluate_model(model2)

# model results will differ only slightly due to probabilistic sampling of predict()
print(rmse_model1, rmse_model2)
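Once EvalCallback has filled eval_logger.logs, the per-epoch entries (dicts with "rmse" and "epoch" keys, as built in load_and_evaluate) can be post-processed with plain Python. The two helpers below are hypothetical additions, not part of Darts: one picks the logged epoch with the lowest RMSE, the other counts how many logged epochs came after it, which could serve as a patience criterion for a custom early stop.

```python
from typing import List


def best_epoch(logs: List[dict]) -> dict:
    """Return the log entry with the lowest RMSE (the earliest one if several tie)."""
    if not logs:
        raise ValueError("no evaluation logs collected yet")
    # min() keeps the first minimal entry, i.e. the earliest epoch for ordered logs
    return min(logs, key=lambda entry: entry["rmse"])


def epochs_without_improvement(logs: List[dict]) -> int:
    """Count the logged epochs that come after the best one."""
    best = best_epoch(logs)
    return sum(1 for entry in logs if entry["epoch"] > best["epoch"])
```

For example, best_epoch(eval_logger.logs)["epoch"] gives the best epoch seen so far. Inside on_train_epoch_start, setting trainer.should_stop = True once epochs_without_improvement(self.logs) reaches some patience value is one possible way to stop training early; trainer.should_stop is a PyTorch Lightning flag, while the patience threshold is an assumption you would choose yourself.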