unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Saving/loading model with a checkpoint callback #2367

Closed: tRosenflanz closed this issue 4 months ago

tRosenflanz commented 4 months ago

Is your feature request related to a current problem? Please describe. The current torch model checkpointing logic is quite rigid. It only allows tracking the loss rather than other metrics, and it uses somewhat hard-coded directories, which makes restoring the best model for a given metric challenging.

Describe proposed solution Either or both of the following:

  1. allow overriding the metric/arguments of the checkpointing callback when save_checkpoints is enabled;
  2. make load_from_checkpoint more flexible.

Describe potential alternatives I worked around this issue by abusing the behavior of os.path.join with absolute paths (which must start with /). From the Python documentation: "If a segment is an absolute path (which on Windows requires both a drive and a root), then all previous segments are ignored and joining continues from the absolute path segment."
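The os.path.join rule quoted above is easy to check in isolation. The sketch below uses posixpath (the POSIX flavour of os.path) so the result does not depend on the host OS; the directory and file names are made up for illustration.

```python
import posixpath

# A fully relative join keeps every segment.
relative = posixpath.join("darts_logs", "my_model", "checkpoints", "best.ckpt")

# Once a segment is absolute, every segment before it is discarded.
absolute = posixpath.join("darts_logs", "my_model", "checkpoints", "/tmp/run/best.ckpt")

print(relative)  # darts_logs/my_model/checkpoints/best.ckpt
print(absolute)  # /tmp/run/best.ckpt
```

This is why passing an absolute work_dir or file_name overrides whatever prefix the library joins in front of it.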

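# Load the best checkpoint.
# Assumes `m` is the trained TSMixerModel and `checkpoint_callback` is the
# pytorch_lightning ModelCheckpoint that was passed to its trainer.
import os
from darts.models import TSMixerModel

cwd = os.getcwd()
checkp_dir = os.path.dirname(checkpoint_callback.best_model_path)
# Save the base model (.pth.tar) into the checkpoint directory
m.save(os.path.join(checkp_dir, "_model.pth.tar"))
best_model = TSMixerModel.load_from_checkpoint(
    "feel free to put anything here because it actually doesn't matter",
    work_dir=os.path.join(cwd, checkp_dir),  # absolute path, ignores internal path prefixes
    file_name=os.path.join(cwd, checkpoint_callback.best_model_path),  # absolute path, ignores internal path prefixes
)

# Copying to the destination can be done with a plain save
# (OUTPUT_DIR is assumed to be defined elsewhere)
name = "darts_model"
destination = os.path.join(OUTPUT_DIR, name)
best_model.save(destination)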

Additional context Clearly the solution above is not ideal, since it relies on a somewhat hidden behavior of os.path.join and on knowledge of the method's internals. Perhaps it is the user's responsibility to set the proper parameters when a custom trainer is used, but load_from_checkpoint could accept explicit paths for the model and the checkpoint, and skip the current directory/naming logic when absolute paths are given.
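To make the proposal concrete, here is a minimal sketch of explicit-path resolution. `resolve_checkpoint_paths` and its `base_model_path` parameter are hypothetical (they are not part of Darts), and the default layout it falls back to (`<work_dir>/<model_name>/_model.pth.tar` plus a `checkpoints/` subfolder for the .ckpt files) is an assumption about Darts' defaults. posixpath is used so the result is the same on every OS.

```python
import posixpath
from typing import Optional, Tuple

INIT_MODEL_NAME = "_model.pth.tar"  # assumed name of the saved base model


def resolve_checkpoint_paths(
    model_name: str,
    work_dir: str,
    file_name: str,
    base_model_path: Optional[str] = None,
) -> Tuple[str, str]:
    """Hypothetical helper: return (base_model_path, checkpoint_path)."""
    model_dir = posixpath.join(work_dir, model_name)
    if base_model_path is None:
        # No explicit base model given: fall back to the assumed default layout.
        base_model_path = posixpath.join(model_dir, INIT_MODEL_NAME)
    if posixpath.isabs(file_name):
        # Explicit absolute checkpoint path: use it as given, skip the directory logic.
        ckpt_path = file_name
    else:
        ckpt_path = posixpath.join(model_dir, "checkpoints", file_name)
    return base_model_path, ckpt_path
```

The absolute-path branch gives the same result os.path.join would, but it states the intent in the code instead of relying on a side effect of the join.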

madtoinou commented 4 months ago

Hi @tRosenflanz,

I'm not sure why the work_dir argument does not work for you; it should let you specify a custom path for the automatic checkpoints when creating the torch model:

from darts.models import DLinearModel

model = DLinearModel(input_chunk_length=4, output_chunk_length=1, save_checkpoints=True, work_dir="../custom_dir")

As for the monitored metric, according to the PyTorch Lightning documentation it should be possible to use metrics other than the loss to identify the best checkpoint, since all the torch_metrics are logged by the trainer:

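from darts.datasets import AirPassengersDataset
from darts.models import DLinearModel
from torchmetrics import ExplainedVariance
from pytorch_lightning.callbacks import ModelCheckpoint

# Example data (any train/validation split works)
series = AirPassengersDataset().load().astype("float32")
train, val = series.split_after(0.7)

# Darts prepends "val_" and "train_" to the names of the torch_metrics entries;
# a higher explained variance is better, while ModelCheckpoint defaults to mode="min"
checkpoint_callback = ModelCheckpoint(monitor="val_ExplainedVariance", mode="max")

# Add the callback
model = DLinearModel(
    input_chunk_length=4,
    output_chunk_length=1,
    save_checkpoints=True,
    torch_metrics=ExplainedVariance(),
    pl_trainer_kwargs={"callbacks": [checkpoint_callback]},
)

model.fit(train, val_series=val, epochs=10)
new_model = DLinearModel.load_from_checkpoint(model.model_name, best=True)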

Please note that specifying the dirpath and filename arguments of ModelCheckpoint will export the .ckpt to the desired path, but the .pth.tar will not be exported there, so this checkpoint is not usable unless the .pth.tar is copied to the expected relative path (we need to investigate whether this can be automated when such a callback is provided). This makes the work_dir argument the ideal way to change the checkpoint path.
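A possible way to automate that copy is sketched below, under the assumption that Darts looks for checkpoints in `<work_dir>/<model_name>/checkpoints/` and for the base model in `<work_dir>/<model_name>/_model.pth.tar`, and that the base model was already saved there, as Darts does by default when save_checkpoints=True. `copy_into_darts_layout` is a hypothetical helper, not part of Darts: it copies a .ckpt written by a custom ModelCheckpoint into that folder and returns the name to pass as file_name.

```python
import os
import shutil


def copy_into_darts_layout(ckpt_path: str, work_dir: str, model_name: str) -> str:
    """Hypothetical helper: copy a custom .ckpt into the (assumed) Darts
    checkpoint folder and return its file name for load_from_checkpoint."""
    # Assumed Darts layout: <work_dir>/<model_name>/checkpoints/
    target_dir = os.path.join(work_dir, model_name, "checkpoints")
    os.makedirs(target_dir, exist_ok=True)
    file_name = os.path.basename(ckpt_path)
    # copy2 also preserves the file's modification time
    shutil.copy2(ckpt_path, os.path.join(target_dir, file_name))
    return file_name
```

Afterwards, something like DLinearModel.load_from_checkpoint(model_name, work_dir=work_dir, file_name=name) should find both files without any path tricks.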

Manually saving the model as you described is probably the most practical way to copy the checkpoints to the desired place.

Let me know if it helps.

tRosenflanz commented 4 months ago

I think the automatic concatenation of subpaths onto work_dir and its interaction with custom training confused me a bit. I am switching from pytorch_forecasting, where I had to define my own trainer, so the Darts way of handling it was unfamiliar at first. I see no issue with the method you provided.

eye4got commented 1 month ago

This doesn't appear to actually change the monitored metric. According to the documentation of the best parameter of the load_from_checkpoint() method: "best (bool) – If set, will retrieve the best model (according to validation loss) instead of the most recent one. Only is ignored when file_name is given."

Similarly, the checkpoint files are still named like "best-epoch=24-val_loss=0.27", and none of them include "val_MeanAbsolutePercentageError" (the metric I tried to use).

I even set auto_insert_metric_name explicitly in case it made a difference (although it should be True by default):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor="val_MeanAbsolutePercentageError", auto_insert_metric_name=True)
my_stopper = EarlyStopping(
    monitor="val_MeanAbsolutePercentageError",
    patience=400,
    min_delta=0.01,
    stopping_threshold=0.28,
    mode="min",
)

# model_params: the model's constructor arguments, defined elsewhere
model_params["pl_trainer_kwargs"] = {"callbacks": [checkpoint_callback, my_stopper]}

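The file names quoted above, such as "best-epoch=24-val_loss=0.27", follow a filename template. Below is a simplified sketch of how PyTorch Lightning's ModelCheckpoint expands such a template when auto_insert_metric_name=True; it is an approximation for illustration only (the real implementation also handles metric names containing "/", missing metrics, versioning and the .ckpt suffix). The template "best-{epoch}-{val_loss:.2f}" is assumed here because it reproduces those names; a ModelCheckpoint created without a filename uses Lightning's default "{epoch}-{step}" template instead, so its files would not mention the monitored metric either.

```python
import re
from typing import Dict


def format_checkpoint_name(template: str, metrics: Dict[str, float]) -> str:
    """Simplified sketch of ModelCheckpoint's name expansion with
    auto_insert_metric_name=True: "{epoch}" becomes "epoch={epoch}"."""
    # Prefix every "{name" group with "name=" so the metric name ends up in the file name.
    expanded = re.sub(r"\{([^{}:]+)", lambda m: f"{m.group(1)}={{{m.group(1)}", template)
    # Fill in the values, honouring format specs such as ":.2f".
    return expanded.format(**metrics)


print(format_checkpoint_name("best-{epoch}-{val_loss:.2f}", {"epoch": 24, "val_loss": 0.2712}))
# best-epoch=24-val_loss=0.27
```

A custom template such as "best-{epoch}-{val_MeanAbsolutePercentageError:.2f}" would, by the same rule, produce names that contain the monitored metric.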
madtoinou commented 1 month ago

Hi,

I made a mistake in my previous answer: some additional parameters (dirpath and filename) must be provided to ModelCheckpoint so that the weights are properly exported:

import os
from torchmetrics import ExplainedVariance
from pytorch_lightning.callbacks import ModelCheckpoint

from darts import TimeSeries
from darts.models import DLinearModel
from darts.datasets import AirPassengersDataset
from darts.models.forecasting.torch_forecasting_model import _get_checkpoint_folder

# Read data
series = AirPassengersDataset().load()
series = series.astype("float32")

# Create training and validation sets:
train, val = series.split_after(0.7)

# Darts prepends "val_" and "train_" to the names of the torch_metrics entries;
# a higher explained variance is better, while ModelCheckpoint defaults to mode="min"
checkpoint_callback = ModelCheckpoint(
    monitor="val_ExplainedVariance",
    mode="max",
    filename="best-{epoch}-{val_ExplainedVariance:.2f}",
    dirpath=_get_checkpoint_folder(
        work_dir=os.path.join(os.getcwd(), "darts_logs"),
        model_name="example_custom_ckpt",
    ),
)

# Add the callback
model = DLinearModel(
    input_chunk_length=4,
    output_chunk_length=1,
    save_checkpoints=True,
    torch_metrics=ExplainedVariance(),
    pl_trainer_kwargs={"callbacks": [checkpoint_callback]},
    model_name="example_custom_ckpt",
)

model.fit(train, val_series=val, epochs=10)

Running this snippet makes it obvious that two separate checkpoints are created (one for the loss and one for the explained variance).

Note that you will have to specify the file_name argument of the load/load_from_checkpoint methods, since otherwise it is ambiguous which "best" checkpoint to load.
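To choose that file_name, os.path.basename(checkpoint_callback.best_model_path) is probably the simplest option, since ModelCheckpoint records the path of its best checkpoint. When only the file names are at hand (for example from os.listdir on the checkpoints folder), a hypothetical helper such as pick_best_checkpoint below can select the best file for a given metric by parsing the metric=value pairs that auto_insert_metric_name writes into the names. It is a sketch, not a Darts or Lightning function, and the values it compares are rounded by the :.2f format.

```python
import re
from typing import Iterable, Optional


def pick_best_checkpoint(file_names: Iterable[str], metric: str, mode: str = "min") -> Optional[str]:
    """Hypothetical helper: return the checkpoint file whose name contains the
    best "metric=value" entry, or None if no file mentions the metric."""
    pattern = re.compile(rf"{re.escape(metric)}=(-?\d+(?:\.\d+)?)")
    scored = []
    for name in file_names:
        match = pattern.search(name)
        if match:  # skips e.g. "last.ckpt" and checkpoints of other metrics
            scored.append((float(match.group(1)), name))
    if not scored:
        return None
    choose = max if mode == "max" else min
    return choose(scored)[1]
```

For the snippet above, something like DLinearModel.load_from_checkpoint("example_custom_ckpt", file_name=pick_best_checkpoint(names, "val_ExplainedVariance", mode="max")) would then load the explained-variance checkpoint rather than the loss-based one.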