@LogWell opened this issue 1 year ago

@LogWell @Borda
I think this feature can be accommodated in `src/lightning/pytorch/callbacks/early_stopping.py`. The proposed modifications are marked below as Change 1 through Change 5; please provide feedback on whether to incorporate them.
Under `__init__`:

```python
def __init__(
    self,
    monitor: str,
    min_delta: float = 0.0,
    patience: int = 3,
    verbose: bool = False,
    mode: str = "min",
    strict: bool = True,
    check_finite: bool = True,
    stopping_threshold: Optional[float] = None,
    divergence_threshold: Optional[float] = None,
    check_on_train_epoch_end: Optional[bool] = None,
    log_rank_zero_only: bool = False,
    use_percentage: bool = False,  # Change 1: added use_percentage parameter
):
    super().__init__()
    ...
    self.use_percentage = use_percentage  # Change 2: store the use_percentage flag
    self.min_delta_abs = min_delta  # Change 3: keep the user-supplied min_delta (interpreted as a percentage when use_percentage is True)
    ...
```
Under `_run_early_stopping_check`:

```python
def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
    """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
    logs = trainer.callback_metrics

    if trainer.fast_dev_run or not self._validate_condition_metric(logs):  # disable early_stopping with fast_dev_run
        return

    current = logs[self.monitor].squeeze()

    # Change 4: recompute min_delta as a percentage of the best score if use_percentage is True
    if self.use_percentage and torch.isfinite(self.best_score) and self.best_score != 0:  # avoid division by zero
        self.min_delta = self.min_delta_abs * abs(self.best_score) / 100.0

    should_stop, reason = self._evaluate_stopping_criteria(current)

    # stop every ddp process if any world process decides to stop
    should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
    trainer.should_stop = trainer.should_stop or should_stop
    if should_stop:
        self.stopped_epoch = trainer.current_epoch
    if reason and self.verbose:
        self._log_info(trainer, reason, self.log_rank_zero_only)
```
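For illustration, with `mode="min"`, `min_delta=2.0`, `use_percentage=True`, and a current best validation loss of 0.50 (numbers are hypothetical), the effective threshold would be recomputed as:

```python
# Hypothetical values to illustrate the proposed percentage-based threshold.
best_score = 0.50          # current best monitored value
min_delta_percent = 2.0    # user-supplied min_delta, interpreted as 2%

effective_min_delta = min_delta_percent * abs(best_score) / 100.0
print(effective_min_delta)  # 0.01 -> the metric must improve by at least 0.01 to count
```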
Under `_improvement_message`:

```python
def _improvement_message(self, current: Tensor) -> str:
    """Formats a log message that informs the user about an improvement in the monitored score."""
    if torch.isfinite(self.best_score):
        if self.use_percentage:  # Change 5: report the improvement as a percentage when use_percentage is True
            improvement = abs(current - self.best_score) / abs(self.best_score) * 100.0  # percentage improvement
            msg = (
                f"Metric {self.monitor} improved by {improvement:.3f}% >="
                f" min_delta = {self.min_delta_abs:.3f}%. New best score: {current:.3f}"
            )
        else:
            msg = (
                f"Metric {self.monitor} improved by {abs(self.best_score - current):.3f} >="
                f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
            )
    else:
        msg = f"Metric {self.monitor} improved. New best score: {current:.3f}"
    return msg
```
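If accepted, usage would look roughly like this (note that `use_percentage` is the parameter proposed in this issue, not an existing Lightning option):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# Proposed behaviour: stop when val_loss fails to improve by at least 1% of the best score
# for 3 consecutive checks. `use_percentage` does not exist in Lightning today.
early_stop = EarlyStopping(monitor="val_loss", min_delta=1.0, patience=3, mode="min", use_percentage=True)
trainer = Trainer(callbacks=[early_stop])
```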
@LogWell What would the reference value be when min_delta is a percentage? I imagine you would like to express "stop training if the loss changes by less than 5%". Relative to what? Is it the highest loss ever observed? Could you give us an example?
What I mean is to add different early-stopping criteria, for example:
1. abs < 0.1: val_loss(cur)
2. delta < 0.002: val_loss(pre) - val_loss(cur)
3. ratio < 5%: (val_loss(pre) - val_loss(cur)) / val_loss(pre)

I'm not sure if these methods are useful enough.
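A minimal sketch of the three criteria, assuming `prev` and `cur` are consecutive validation-loss values (the helper names here are hypothetical and not part of Lightning):

```python
# Hypothetical helpers illustrating the three stopping rules above.
def stop_on_abs(cur: float, threshold: float = 0.1) -> bool:
    # 1. absolute criterion: stop once the loss itself is small enough
    return cur < threshold

def stop_on_delta(prev: float, cur: float, threshold: float = 0.002) -> bool:
    # 2. difference criterion: stop once the improvement over the previous check is tiny
    return (prev - cur) < threshold

def stop_on_ratio(prev: float, cur: float, threshold: float = 0.05) -> bool:
    # 3. relative criterion: stop once the improvement is less than 5% of the previous loss
    return (prev - cur) / prev < threshold
```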
Description & Motivation

When measuring errors, there are absolute errors and relative errors. It would be a good idea to monitor the percentage reduction of `val_loss` for early stopping!

Pitch

No response

Alternatives

No response

Additional context

No response
cc @borda @carmocca @awaelchli