facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

How to build an Early Stopping Hook #531

Open Pedrexus opened 2 years ago

Pedrexus commented 2 years ago

🚀 Feature

I wish to integrate Early Stopping into VISSL

Motivation & Examples

Early Stopping is a useful mechanism, already integrated into several libraries and frameworks, that helps when training many models for many epochs.

https://en.wikipedia.org/wiki/Early_stopping
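For context, the core of patience-based early stopping is small and framework-agnostic. A minimal sketch (illustrative only, names are mine, not VISSL API):

```python
class EarlyStopper:
    """Track validation loss and signal when training should stop."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to tolerate without improvement
        self.min_delta = min_delta  # smallest change that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Return True once the loss has stopped improving for `patience` epochs."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs > self.patience
```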

Note

This is actually a request for assistance: I already have a working Early Stopping hook, but it has not been reliable in multi-GPU scenarios, where training simply gets stuck when it tries to stop. Could you help me solve this problem?

Example of my early_stopping_hook.py file:

from pathlib import Path
import numpy as np
from structlog import get_logger

from classy_vision.hooks.classy_hook import ClassyHook
from vissl.utils.io import create_file_symlink
from vissl.utils.env import get_machine_local_and_dist_rank

from common.tools import PersistentDict  # https://code.activestate.com/recipes/576642/

INF = float("inf")

logger = get_logger()

def is_primary():
    dist_rank = get_machine_local_and_dist_rank()[1]
    return dist_rank == 0

class EarlyStoppingHook(ClassyHook):
    on_start = ClassyHook._noop
    on_forward = ClassyHook._noop
    on_backward = ClassyHook._noop
    on_update = ClassyHook._noop
    on_step = ClassyHook._noop
    on_end = ClassyHook._noop

    hook_filename = "early_stopping.json"
    pers_format = "json"

    def __init__(self, cfg) -> None:
        super().__init__()

        setup = cfg.HOOKS.EARLY_STOPPING_SETUP

        self.min_delta: float = setup.MIN_DELTA
        self.patience: int = setup.PATIENCE
        self.warmup: int = setup.WARMUP_RANGE
        self.is_self_supervised = "cross_entropy" not in cfg.LOSS.name
        self.cache_filepath = Path(cfg.CHECKPOINTS.DIR) / self.hook_filename

        logger.info(f"{type(self).__name__} enabled")

    def cache(self, **kwargs):
        return PersistentDict(
            str(self.cache_filepath),
            format=self.pers_format,
            flag="c" if is_primary() else "r",
            **kwargs,
        )

    def set_cache_to_finished(self):
        with self.cache() as db:
            db["finished"] = True

    @property
    def is_finished(self) -> bool:
        with self.cache() as db:
            return bool(db.get("finished", False))

    def stop_task(self, task):
        with self.cache() as db:
            # task.train_phase_idx = db["min_loss_epoch"]
            # task.num_train_phases = db["min_loss_epoch"] + 1
            min_loss_epoch = db["min_loss_epoch"]  # read before the cache closes
            task.phases = []

        logger.warning("task stopped", min_loss_epoch=min_loss_epoch)

    def set_final_checkpoint(self):
        dir = self.cache_filepath.parent

        with self.cache() as db:
            checkpoint_filepath = (
                dir / f"model_final_checkpoint_phase{db['min_loss_epoch']}.torch"
            )

        source_file = dir / f'model_phase{db["min_loss_epoch"]}.torch'

        if source_file.is_file():
            source_file.rename(checkpoint_filepath)

        logger.warning("checkpoint renamed", checkpoint_filepath=checkpoint_filepath)

    def update_symlink_to_final_checkpoint(self):
        # raise NotImplementedError("don't use this, breaks vissl")
        dir = self.cache_filepath.parent
        symlink_dest_file = str(dir / "checkpoint.torch")

        with self.cache() as db:
            source_file = str(
                dir
                / f'model_final_min_loss_checkpoint_phase{db["min_loss_epoch"]}.torch'
            )

        create_file_symlink(source_file, symlink_dest_file)

        logger.warning("symlink updated", source_file=source_file)

    def stop(self, task: "tasks.ClassyTask"):
        self.stop_task(task)

        if is_primary() and not self.is_finished:
            self.set_cache_to_finished()
            self.set_final_checkpoint()
            self.update_symlink_to_final_checkpoint()

    def on_start(self, task: "tasks.ClassyTask") -> None:
        """Prevents experiment from running again"""
        if self.is_finished:
            self.stop(task)

    def on_phase_start(self, task: "tasks.ClassyTask") -> None:
        """Called at the start of each phase."""
        # this is going to be used by non-primary gpus on phase start
        if not is_primary():
            return

        if self.is_finished:
            self.stop(task)

        self.phase_testing_losses = []

    def on_loss_and_meter(self, task: "tasks.ClassyTask") -> None:
        """Logs testing loss"""
        if not is_primary() or task.train:
            return

        loss = task.last_batch.loss.data.cpu().item()
        self.phase_testing_losses.append(loss)

    def on_phase_end(self, task: "tasks.ClassyTask") -> None:
        if (
            not is_primary()
            or task.train
            or self.is_self_supervised
            or self.warmup >= task.train_phase_idx
        ):
            return

        val_loss = np.mean(self.phase_testing_losses)
        epoch = task.train_phase_idx

        with self.cache() as db:
            if val_loss < db.get("min_loss", INF):
                db["min_loss"] = val_loss
                db["min_loss_epoch"] = epoch
                return

            loss_delta = abs(db["min_loss"] - val_loss)
            epoch_diff = abs(epoch - db["min_loss_epoch"])

        if loss_delta > self.min_delta or epoch_diff > self.patience:
            self.stop(task)
            logger.warning("early stopping")
            return

        logger.info("skipped early stopping")

The PersistentDict works just like Python's builtin shelve module.
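To illustrate the access pattern the hook relies on, here is the equivalent usage with the stdlib `shelve` module (the file path is a made-up example; PersistentDict adds a `format` argument but the flag semantics are the same):

```python
import os
import shelve
import tempfile

# Hypothetical state file, standing in for the hook's early_stopping.json.
path = os.path.join(tempfile.mkdtemp(), "early_stopping_state")

# flag="c": create the store if missing (what the primary rank does).
with shelve.open(path, flag="c") as db:
    db["min_loss"] = 0.42
    db["min_loss_epoch"] = 7

# flag="r": read-only access (what non-primary ranks do).
with shelve.open(path, flag="r") as db:
    state = dict(db)
```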

I believe it might just require tweaking the .stop_task(task) method, but I have not managed to get it working so far.

QuentinDuval commented 2 years ago

Hi @Pedrexus,

First of all, thanks a lot for this! I think this is a super important feature that is indeed really missing from VISSL.

Now, reading the code, and based on the description of unreliable multi-node training, I think the issue might come from the combination of the is_primary function and the PersistentDict.

What might happen is a race condition that leaves some workers stuck instead of quitting: because of the way multi-node training is done, it proceeds in lockstep, where each worker has synchronisation points with the others.

I propose we try something which relies on standard PyTorch and might make it work better:

Could you please try something like this and tell me what happens?
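(The snippet referenced above does not appear in this thread. The following is my own sketch of the pattern being described, assuming standard `torch.distributed` collectives: every rank calls the same collective at the same point in the loop, so the stop decision reaches all workers and no rank deadlocks waiting for the others.)

```python
import torch
import torch.distributed as dist

def should_stop_all_ranks(local_should_stop: bool) -> bool:
    """Share the stop decision across all workers.

    All ranks must call this at the same point in the training loop;
    the all_reduce then doubles as a synchronisation barrier, so every
    worker agrees on whether to stop.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return local_should_stop  # single-process fallback

    flag = torch.tensor(
        [1.0 if local_should_stop else 0.0],
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    # MAX reduction: if any rank (in practice, the primary) votes to
    # stop, every rank sees 1 after the collective.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```

In the hook, the primary rank would compute the early-stopping condition as before, and every rank (primary or not) would call `should_stop_all_ranks(...)` at phase end before deciding to clear `task.phases`.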

Thank you again, Quentin

Pedrexus commented 2 years ago

Hello @QuentinDuval,

thanks for the reply. This is a good idea, and I will try to implement it as soon as possible.

Thanks, Pedro