sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Clarification on MetricsCallback in optimize_hyperparameters() #652

Open georgeblck opened 3 years ago

georgeblck commented 3 years ago

Expected behavior

I am tuning a TFT model with the function optimize_hyperparameters. Each trial returns my specified metric (in this case MAE) like so:

[I 2021-08-20 18:39:06,537] Trial 219 finished with value: 11.637566566467285 and parameters: {'gradient_clip_val': 0.9433828074072016, 'hidden_size': 11, 'dropout': 0.15422157867252945, 'hidden_continuous_size': 8, 'attention_head_size': 2, 'learning_rate': 0.08458313458926149}. Best is trial 219 with value: 11.637566566467285.

I expect this returned value (11.63) to come from the best epoch of that trial. The checkpoint callback used in the function also saves only the best epoch of each trial:

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=os.path.join(model_path, "trial_{}".format(trial.number)), filename="{epoch}", monitor="val_loss"
)

Actual behavior

However, the trial value reported by optuna is actually the value from the last epoch of that trial, so it depends heavily on max_epochs. Within a single trial this discrepancy works in your favor, e.g. for the trial above the best epoch had an MAE of 11.512, which is better than the reported 11.637.

But for the optimization as a whole it is bad, because optuna never sees the actual best value of each trial.

As I see it, it all comes down to this callback, which provides optuna with the final value:

class MetricsCallback(Callback):
    """PyTorch Lightning metric callback."""

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        self.metrics.append(trainer.callback_metrics)
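If I read optimize_hyperparameters correctly, the objective hands optuna the last entry of this list (something like metrics_callback.metrics[-1]["val_loss"]) rather than the minimum over all entries. A tiny self-contained sketch of the discrepancy, using made-up per-epoch values in the spirit of the trial above:

```python
# Simulated trainer.callback_metrics snapshots, one per validation epoch,
# in the order MetricsCallback would append them (made-up values).
metrics = [
    {"val_loss": 12.300},
    {"val_loss": 11.512},  # best epoch -> this is what ModelCheckpoint saves
    {"val_loss": 11.637},  # last epoch
]

reported_to_optuna = metrics[-1]["val_loss"]              # last epoch's value
best_across_epochs = min(m["val_loss"] for m in metrics)  # best epoch's value

print(reported_to_optuna)  # 11.637
print(best_across_epochs)  # 11.512
```

So the checkpoint on disk and the value optuna optimizes over can disagree for every trial.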

Maybe I do not understand these callbacks correctly. If so, please close this issue.

Also thanks a lot for this great package.

nicocheh commented 2 years ago

Having the same situation, any news on this?

georgeblck commented 2 years ago

For the moment, I would suggest writing your own optimize_hyperparameters function with an adjusted MetricsCallback that keeps the minimal value instead of the last one.
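A minimal sketch of what such an adjusted callback could look like. MinMetricsCallback and its monitor parameter are my own naming, not part of the library; it keeps a running minimum instead of a list, and the import guard is only there so the snippet also runs without Lightning installed:

```python
from types import SimpleNamespace

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:  # allow running this sketch without Lightning
    Callback = object

class MinMetricsCallback(Callback):
    """Track the best (minimal) monitored value across epochs, not the last."""

    def __init__(self, monitor="val_loss"):
        super().__init__()
        self.monitor = monitor
        self.best = float("inf")

    def on_validation_end(self, trainer, pl_module):
        value = trainer.callback_metrics.get(self.monitor)
        if value is not None:
            self.best = min(self.best, float(value))

# Simulate three validation epochs with a fake trainer object
cb = MinMetricsCallback()
for loss in (12.3, 11.512, 11.637):
    cb.on_validation_end(SimpleNamespace(callback_metrics={"val_loss": loss}), None)
print(cb.best)  # 11.512 -- the best epoch, not the last one
```

In your own objective you would then return cb.best to optuna instead of the last entry of the metrics list.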