open-mmlab / mmengine

OpenMMLab Foundational Library for Training Deep Learning Models
https://mmengine.readthedocs.io/
Apache License 2.0

[Feature] Resume Validation Automatically #1394

Open collinmccarthy opened 11 months ago

collinmccarthy commented 11 months ago

What is the feature?

Suppose you're training with checkpointing. You save a checkpoint every 10 epochs and run validation every 10 epochs. After the 10th epoch you save a checkpoint before validation, start your validation loop, and the program crashes. Maybe you got preempted on a cluster, maybe you ran out of memory, maybe there was an annotation issue, etc. You fix the bug (if there was one) and restart, but the validation loop gets skipped and training resumes at the 11th epoch.

I don't think silently skipping validation is a good idea, and I think it could be prevented by restructuring how and when the validation loop (or the check for it) runs. If validation crashes, it should automatically re-run when you resume.

Is this already implemented somewhere/somehow? If not, do you have any recommendations about how to go about it? I've hacked together something similar before in Detectron2 but I'd rather do it the "right way" here.

Any other context?

If it's not already implemented I thought about the following:

  1. When validation starts, write out a temporary file and delete it when validation finishes; if the file still exists when resuming, re-run validation
    • Pros: easy to check
    • Cons: have to remove the temporary file, have to handle I/O issues to be more robust
  2. At the start of a training epoch, check whether validation should have run at the end of the previous epoch and, if so, whether its results exist in results.json or similar
    • Pros: easy to understand, robust
    • Cons: more difficult to implement and verify
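A minimal sketch of approach 1, assuming a single process and a work directory on storage that survives preemption. The sentinel filename and the helpers validation_sentinel and needs_rerun are hypothetical, not part of MMEngine:

```python
import tempfile
from contextlib import contextmanager
from pathlib import Path

SENTINEL_NAME = "validation_in_progress"  # hypothetical sentinel filename


@contextmanager
def validation_sentinel(work_dir: Path):
    """Create a sentinel file while validation runs; remove it only on success."""
    sentinel = work_dir / SENTINEL_NAME
    sentinel.touch()
    yield
    # Only reached if the body finished without raising, so a crash
    # (or a killed process) leaves the sentinel behind for the next run.
    sentinel.unlink()


def needs_rerun(work_dir: Path) -> bool:
    """On resume, a leftover sentinel means the last validation never finished."""
    return (work_dir / SENTINEL_NAME).exists()


work_dir = Path(tempfile.mkdtemp())

# Simulate a validation run that crashes partway through.
try:
    with validation_sentinel(work_dir):
        raise MemoryError("simulated OOM during validation")
except MemoryError:
    pass
print(needs_rerun(work_dir))  # True: the sentinel was left behind

# After the fix, validation completes and the sentinel is cleaned up.
with validation_sentinel(work_dir):
    pass
print(needs_rerun(work_dir))  # False
```

The cons listed above show up directly here: with distributed training you would likely want only rank 0 to create and delete the file, and stale sentinels from unrelated crashes also trigger a re-run.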
collinmccarthy commented 11 months ago

I forgot that this is how I did it before:

  3. When saving a resume checkpoint right before validation, save metadata like needs_val=True; after validation finishes, save a new checkpoint with needs_val=False; when resuming from a checkpoint, check this flag
    • Pros: Easy to check/understand, robust
    • Cons: Requires saving two checkpoints at every step where you both checkpoint and validate, which can be slow; requires some additional logic in a few places
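To make the control flow of approach 3 concrete, here is a toy, framework-free simulation, assuming epoch-based training with a checkpoint after every epoch and validation every val_interval epochs. Checkpoints are plain dicts appended to a list, and train_with_resume is a hypothetical stand-in, not an MMEngine API:

```python
from typing import List, Optional


def train_with_resume(
    max_epochs: int,
    val_interval: int,
    checkpoints: List[dict],
    crash_in_val_at: Optional[int] = None,
) -> List[int]:
    """Train from the latest checkpoint (if any); return the epochs validated in this run."""
    validated: List[int] = []
    start_epoch = 0
    if checkpoints:
        meta = checkpoints[-1]["meta"]
        start_epoch = meta["epoch"]
        # Resume check: the previous run saved its checkpoint but never
        # finished validating, so validate again before training continues.
        if meta["needs_val"]:
            validated.append(start_epoch)
            checkpoints.append({"meta": {"epoch": start_epoch, "needs_val": False}})

    for epoch in range(start_epoch + 1, max_epochs + 1):
        # ... one epoch of training would happen here ...
        will_validate = epoch % val_interval == 0
        # First save, before validation: flag whether validation is pending.
        checkpoints.append({"meta": {"epoch": epoch, "needs_val": will_validate}})
        if will_validate:
            if epoch == crash_in_val_at:
                raise RuntimeError(f"simulated crash while validating epoch {epoch}")
            validated.append(epoch)
            # Second save, after validation: clear the flag.
            checkpoints.append({"meta": {"epoch": epoch, "needs_val": False}})
    return validated


ckpts: List[dict] = []
try:
    train_with_resume(max_epochs=30, val_interval=10, checkpoints=ckpts, crash_in_val_at=10)
except RuntimeError:
    pass  # e.g. preemption or OOM during validation

print(ckpts[-1]["meta"])  # {'epoch': 10, 'needs_val': True}
print(train_with_resume(max_epochs=30, val_interval=10, checkpoints=ckpts))  # [10, 20, 30]
```

Without the needs_val check, the resumed run would return [20, 30], which is exactly the skipped-validation problem described above.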
collinmccarthy commented 11 months ago

I was able to get this working with approach 3 above, but it required changes in a few different places: a pre-train validation hook, a custom checkpoint hook, and a custom runner.

Here's the PreTrainValHook:

```python
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.runner.checkpoint import _load_checkpoint

from ..runner import CustomRunner


@HOOKS.register_module()
class PreTrainValHook(Hook):
    """Run validation before training begins.

    When resuming on an epoch where validation should have just run at the end of the previous
    epoch, we check if validation succeeded and if not (due to preemption or a bug) we re-run it.
        - Uses ..engine.hooks.CustomCheckpointHook to store the "needs_validation" metadata
        - Uses ..engine.runner.CustomRunner to store resumed_checkpoint filepath
        - See https://github.com/open-mmlab/mmengine/issues/1394
    """

    def __init__(self, partial_num_batches: int = 2):
        # Used by the partial-validation path, which is not included in this snippet
        self._partial_num_batches = partial_num_batches
        super().__init__()

    def before_train(self, runner: CustomRunner):
        # Calling runner.val_loop will build ValLoop, and in turn, build runner.val_loop.dataloader
        #   by calling runner.build_dataloader() with the validation dataloader config
        val_loop = runner.val_loop
        if val_loop is None:
            return  # No validation configured, so there is nothing to re-run

        # Check whether we need to do a full validation loop (from resuming after preemption)
        #   or a partial validation loop to test model/dataloader/evaluator
        full_validation = False
        if runner.resumed_checkpoint is not None:
            # Read the checkpoint dict on CPU just to inspect its metadata. Note that
            #   runner.load_checkpoint() would also load the weights into the model a second time.
            checkpoint = _load_checkpoint(runner.resumed_checkpoint, map_location="cpu")
            full_validation = checkpoint["meta"].get("needs_validation", False)

        if full_validation:
            val_loop.run()
```

And here's the CustomCheckpointHook:

```python
from mmengine.hooks import CheckpointHook
from mmengine.registry import HOOKS
from mmengine.runner.loops import EpochBasedTrainLoop, IterBasedTrainLoop


@HOOKS.register_module()
class CustomCheckpointHook(CheckpointHook):

    def _save_checkpoint(self, runner, maybe_needs_validation=True) -> None:
        """Save the checkpoint and add "needs_validation" metadata if calling before validation.
        Called from CheckpointHook.after_train_iter() and CheckpointHook.after_train_epoch().

        Passing maybe_needs_validation=False forces needs_validation=False. After validation ends,
        after_val_epoch() below saves the "validation finished" checkpoint directly.
        """
        needs_validation = False
        if runner.val_loop is not None and maybe_needs_validation:
            if isinstance(runner.train_loop, IterBasedTrainLoop):
                # This hook is called in IterBasedTrainLoop::run_iter() via 'after_train_iter'.
                # Then iter is incremented by one, and IterBasedTrainLoop::run() calls validation.

                # Same logic as in IterBasedTrainLoop::run() but with iter+1 since we're testing
                #   if we will need validation next iter (but we haven't incremented iter yet)
                next_iter = runner.iter + 1
                needs_validation = (
                    next_iter >= runner.val_begin
                    and next_iter % runner.val_interval == 0
                )
            elif isinstance(runner.train_loop, EpochBasedTrainLoop):
                # This hook is called in EpochBasedTrainLoop::run_epoch() via 'after_train_epoch'.
                # Then epoch is incremented by one, and EpochBasedTrainLoop::run() calls validation.

                # Same logic as in EpochBasedTrainLoop::run() but with epoch+1 since we're testing
                #   if we will need validation next epoch (but we haven't incremented epoch yet)
                next_epoch = runner.epoch + 1
                needs_validation = (
                    next_epoch >= runner.val_begin
                    and next_epoch % runner.val_interval == 0
                )
            else:
                raise RuntimeError(
                    f"Expected type(runner.train_loop) in [IterBasedTrainLoop, EpochBasedTrainLoop]"
                    f", found {type(runner.train_loop)}"
                )

        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter, needs_validation=needs_validation)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step, needs_validation=needs_validation)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def after_val_epoch(self, runner, metrics):
        """Save the best checkpoint (if metrics are available) and a resume checkpoint that
        marks validation as finished.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics
        """
        if len(metrics) == 0:
            runner.logger.warning(
                "Since `metrics` is an empty dict, the behavior to save "
                "the best checkpoint will be skipped in this evaluation."
            )
        else:
            self._save_best_checkpoint(runner, metrics)

        # Always save the resume checkpoint, even with empty metrics; otherwise needs_validation
        #   stays True and validation would re-run on every resume.
        # This hook is called in ValLoop::run(), which is called from EpochBasedTrainLoop::run() or
        #   IterBasedTrainLoop::run() after returning from run_epoch() or run_iter(), which
        #   increments the epoch or iter. Thus we don't want to save the checkpoint with epoch + 1
        #   or iter + 1, but with the current runner.epoch or runner.iter, so we should NOT call
        #   self._save_checkpoint() but rather call self._save_checkpoint_with_step() explicitly
        step = runner.epoch if self.by_epoch else runner.iter
        meta = dict(epoch=runner.epoch, iter=runner.iter, needs_validation=False)
        self._save_checkpoint_with_step(runner, step, meta=meta)
```
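The subtle part of CustomCheckpointHook is the off-by-one: the hook runs before the train loop increments its counter, so it has to evaluate the loop's validation condition at counter + 1. A framework-free sketch of that check (the function names are hypothetical, and the condition mirrors the one used in the hook above):

```python
def will_validate(counter: int, val_begin: int, val_interval: int) -> bool:
    """The train loop's check, evaluated on the already-incremented counter."""
    return counter >= val_begin and counter % val_interval == 0


def needs_validation_flag(counter_in_hook: int, val_begin: int, val_interval: int) -> bool:
    """What the checkpoint hook computes: the counter hasn't been incremented yet, so look ahead."""
    return will_validate(counter_in_hook + 1, val_begin, val_interval)


# runner.epoch inside after_train_epoch runs 0..29 for a 30-epoch job; the checkpoint
# saved at each of those points is named after epoch + 1.
flagged = [e + 1 for e in range(30) if needs_validation_flag(e, val_begin=1, val_interval=10)]
print(flagged)  # [10, 20, 30]
```

If your MMEngine version's train loops also validate on the final epoch or iteration regardless of the interval, the same extra condition would need to be mirrored in the hook.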

And here's the CustomRunner:

```python
from typing import Callable, Optional, Union

from mmengine.registry import RUNNERS
from mmengine.runner import Runner


@RUNNERS.register_module()
class CustomRunner(Runner):

    @property
    def resumed_checkpoint(self) -> Optional[str]:
        # Default to None when resume() was never called (e.g. training from scratch)
        return getattr(self, "_resumed_checkpoint", None)

    def resume(
        self,
        filename: str,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = "default",
    ) -> None:
        # Store resumed checkpoint filename so we can check it in our pre-train validation hook
        self._resumed_checkpoint = filename
        return super().resume(
            filename=filename,
            resume_optimizer=resume_optimizer,
            resume_param_scheduler=resume_param_scheduler,
            map_location=map_location,
        )
```

Please let me know if this is something you would want to turn into a PR, or if you can see a better way of doing this. Thank you.