Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

MetricTracker that also logs the maximum/minimum values #19070

Closed. crazyboy9103 closed this issue 10 months ago

crazyboy9103 commented 10 months ago

Description & Motivation

It is sometimes important to monitor the minimum/maximum values of results that are logged via self.log, e.g. losses and metrics. Although ModelCheckpoint saves checkpoints based on a "monitor" key, it would also be convenient to log the min/max values and see how they change over time in TensorBoard/W&B and other loggers.

Pitch

No-brainer version I'm using currently:

from typing import Any, Dict, List, Literal

import lightning.pytorch as pl

class MetricTracker(pl.Callback):
    r"""
    Automatically logs the maximum/minimum value of a metric. 

    Args:
        metric_config: 
            example = [
                {
                    "name": "train/ssl-loss",
                    "mode": "min",
                    "interval": "step",
                },
                {
                    "name": "valid/online-linear-accuracy",
                    "mode": "max",
                    "interval": "epoch",
                },
            ]

    Example::
        >>> from lightning.pytorch import Trainer
        >>> tracker = MetricTracker(example)
        >>> trainer = Trainer(callbacks=[tracker])
    """
    def __init__(self, metric_config: List[Dict[str, Any]]):
        super().__init__()
        for config in metric_config:
            assert "name" in config
            assert "mode" in config and config["mode"] in ["max", "min"]
            assert "interval" in config and config["interval"] in ["epoch", "step"]
            config.setdefault("value", -1e8 if config["mode"] == "max" else 1e8)

        self.epoch_config = [config for config in metric_config if config["interval"] == "epoch"]
        self.step_config = [config for config in metric_config if config["interval"] == "step"]

    def _log_metrics(self, trainer, interval: Literal["step", "epoch"] = "step"):
        metrics_to_log = {}

        # Pick the configs registered for this logging interval
        config = self.step_config if interval == "step" else self.epoch_config
        metrics = trainer.callback_metrics
        for cfg in config:
            name = cfg["name"]
            mode = cfg["mode"]

            if name in metrics:
                # Update the running min/max for this metric
                cfg["value"] = max(cfg["value"], metrics[name]) if mode == "max" else min(cfg["value"], metrics[name])
                metrics_to_log[f"{mode}_{name}"] = cfg["value"]

        if metrics_to_log and trainer.logger:
            # Log under the current global step (step interval) or the current epoch index (epoch interval)
            trainer.logger.log_metrics(metrics_to_log, step=trainer.global_step if interval == "step" else trainer.current_epoch)

    def on_train_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        self._log_metrics(trainer, interval="step")

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_metrics(trainer, interval="epoch")

    def on_validation_batch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx, dataloader_idx=0):
        self._log_metrics(trainer, interval="step")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_metrics(trainer, interval="epoch")
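
For context, a minimal usage sketch of the callback above; the metric names and Trainer arguments are illustrative only, and any LightningModule that logs these names via self.log would work:

metric_config = [
    {"name": "train/ssl-loss", "mode": "min", "interval": "step"},
    {"name": "valid/online-linear-accuracy", "mode": "max", "interval": "epoch"},
]
tracker = MetricTracker(metric_config)
trainer = pl.Trainer(max_epochs=10, callbacks=[tracker])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule are placeholders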

Alternatives

No response

Additional context

No response

cc @borda @awaelchli

awaelchli commented 10 months ago

@crazyboy9103 Wouldn't it be easier to do this?


# Log the metric
self.log("my_metric", value)

# Also log min and max
self.log("my_metric_min", value, reduce_fx="min")
self.log("my_metric_max", value, reduce_fx="max")
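
For reference, a minimal sketch of where these calls would live, assuming a standard LightningModule; the module, data shapes, and loss are illustrative only, and on_epoch=True is set so the min/max reduction applies to the epoch-level aggregate of the step values:

import lightning.pytorch as pl
import torch
from torch import nn

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)

        # Log the metric as usual (reduced with the default mean at epoch end)
        self.log("my_metric", loss, on_epoch=True)

        # Also log the per-epoch min and max of the same value
        self.log("my_metric_min", loss, on_epoch=True, reduce_fx="min")
        self.log("my_metric_max", loss, on_epoch=True, reduce_fx="max")
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
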
crazyboy9103 commented 10 months ago

I simply wasn't aware of this functionality, as it isn't explained very clearly in the docs. It is definitely easier, thank you for pointing it out.