Lightning-AI / pytorch-lightning


`EarlyStopping` monitor `min_delta` by percentage #18012

Open: LogWell opened this issue 1 year ago

LogWell commented 1 year ago

Description & Motivation

When measuring errors, there are absolute errors and relative errors. It would be useful if early stopping could monitor the relative (percentage) reduction of val_loss, not just the absolute change.

Pitch

No response

Alternatives

No response

Additional context

No response

cc @borda @carmocca @awaelchli

rjarun8 commented 1 year ago

@LogWell @Borda

I feel this feature can be accommodated in src/lightning/pytorch/callbacks/early_stopping.py

The proposed changes are marked in the snippets below as Change 1 through Change 5.

Please provide your feedback so I can incorporate them.

Under `__init__`:

def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False,
        use_percentage: bool = False,  # Change 1: Added use_percentage parameter
    ):
        super().__init__()
        ...
        self.use_percentage = use_percentage  # Change 2: Store the use_percentage parameter
        self.min_delta_abs = min_delta  # Change 3: Store the absolute min_delta value

    ...
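
For context, a minimal usage sketch of the proposed flag from the Trainer side; note that `use_percentage` and the percentage reading of `min_delta` are the proposed additions and are not part of the released API, so this would only run against a build that includes these changes:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# Hypothetical usage: stop when val_loss improves by less than 2% of the
# current best value for 3 consecutive validation checks (proposed API).
early_stop = EarlyStopping(
    monitor="val_loss",
    min_delta=2.0,        # read as a percentage only when use_percentage=True
    patience=3,
    mode="min",
    use_percentage=True,  # proposed parameter from Change 1, not in the released API
)
trainer = Trainer(callbacks=[early_stop], max_epochs=100)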

Under `_run_early_stopping_check`:

def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
        """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
        logs = trainer.callback_metrics

        if trainer.fast_dev_run or not self._validate_condition_metric(logs):  # disable early_stopping with fast_dev_run
            return

        current = logs[self.monitor].squeeze()

        # Change 4: interpret min_delta as a percentage of the current best score if use_percentage is True
        if self.use_percentage and torch.isfinite(self.best_score) and self.best_score != 0:  # avoid division by zero
            # the stock __init__ sign-adjusts min_delta for the mode ("min" flips it negative),
            # so the same sign is reapplied to the recomputed value here
            sign = 1 if self.monitor_op == torch.gt else -1
            self.min_delta = sign * self.min_delta_abs * abs(self.best_score) / 100.0

        should_stop, reason = self._evaluate_stopping_criteria(current)

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
        trainer.should_stop = trainer.should_stop or should_stop
        if should_stop:
            self.stopped_epoch = trainer.current_epoch
        if reason and self.verbose:
            self._log_info(trainer, reason, self.log_rank_zero_only)
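
To make the proposed semantics concrete, here is a small worked example of the recomputed threshold, assuming mode="min" (so an improvement must beat the best score by the effective delta):

import torch

# Worked example of the proposed percentage semantics (assuming mode="min").
best_score = torch.tensor(0.50)
min_delta_pct = 5.0  # user-supplied min_delta with use_percentage=True
effective_delta = min_delta_pct * abs(best_score) / 100.0  # 0.025, recomputed at every check

# Mode "min": the new value must undercut the best score by the effective delta to count.
current = torch.tensor(0.47)
improved = bool(current + effective_delta < best_score)  # 0.495 < 0.50 -> True
print(improved)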

Under `_improvement_message`:

def _improvement_message(self, current: Tensor) -> str:
        """Formats a log message that informs the user about an improvement in the monitored score."""
        if torch.isfinite(self.best_score):
            if self.use_percentage:  # Change 5: format the improvement message as a percentage if use_percentage is True
                # percentage improvement relative to the previous best score
                improvement = abs(self.best_score - current) / abs(self.best_score) * 100.0
                msg = (
                    f"Metric {self.monitor} improved by {improvement:.3f}% >="
                    f" min_delta = {self.min_delta_abs:.3f}%. New best score: {current:.3f}"
                )
            else:
                msg = (
                    f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
                    f" min_delta = {abs(self.min_delta):.3f}. New best score: {current:.3f}"
                )
        else:
            msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
        return msg

awaelchli commented 1 year ago

@LogWell What would the reference value be when min_delta is a percentage? I imagine you would like to express "stop training if the loss changes by less than 5%", but 5% relative to what? The highest loss ever observed? Could you give us an example?

LogWell commented 1 year ago

What I mean is adding different early-stopping criteria, for example:

1. abs < 0.1: stop when val_loss(cur) itself falls below an absolute threshold
2. delta < 0.002: stop when the absolute improvement val_loss(pre) - val_loss(cur) is small
3. ratio < 5%: stop when the relative improvement (val_loss(pre) - val_loss(cur)) / val_loss(pre) is small

I'm not sure whether these criteria are useful enough.
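
For illustration, a minimal standalone sketch of the three criteria above (plain Python, not the Lightning API; the thresholds are the example values from the comment):

def criterion_absolute(cur: float, thresh: float = 0.1) -> bool:
    """1. abs: stop when the current loss itself is below an absolute threshold."""
    return cur < thresh

def criterion_delta(prev: float, cur: float, thresh: float = 0.002) -> bool:
    """2. delta: stop when the absolute improvement over the previous check is small."""
    return (prev - cur) < thresh

def criterion_ratio(prev: float, cur: float, thresh: float = 0.05) -> bool:
    """3. ratio: stop when the relative improvement is below 5% of the previous loss."""
    return prev != 0 and (prev - cur) / prev < thresh

# example: previous val_loss 0.50, current 0.49 -> 2% relative improvement, below 5%
print(criterion_ratio(0.50, 0.49))  # True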