Open collinmccarthy opened 11 months ago
I forgot this is how I did it before.
Save the pre-validation checkpoint with `needs_val=True`, then after validation save a new checkpoint with `needs_val=False`, and check for this flag when resuming from a checkpoint.
I was able to get this working with approach 3 above but it required a few different changes:
This is what it ended up looking like. Here's the PreTrainValHook
from mmengine.registry import HOOKS
from mmengine.hooks import Hook
from ..runner import CustomRunner
from ..runner.loops import PreTrainValLoop
@HOOKS.register_module()
class PreTrainValHook(Hook):
    """Re-run a missed validation pass before training starts.

    If training is resumed from a checkpoint that was written just before a
    scheduled validation, and that validation never finished (preemption, crash,
    bug, ...), the checkpoint metadata still says validation is pending. In that
    case this hook runs the validation loop once before training continues.

    - Relies on ..engine.hooks.CustomCheckpointHook writing the "needs_validation"
      entry into the checkpoint metadata.
    - Relies on ..engine.runner.CustomRunner exposing the resumed checkpoint path.
    - See https://github.com/open-mmlab/mmengine/issues/1394

    Args:
        partial_num_batches: Stored on the hook as ``_partial_num_batches``.
            NOTE(review): nothing in this class reads it yet; presumably meant
            for a short "smoke test" validation pass -- confirm.
    """

    def __init__(self, partial_num_batches: int = 2):
        super().__init__()
        self._partial_num_batches = partial_num_batches

    def before_train(self, runner: CustomRunner):
        """Run the validation loop if the resumed checkpoint still needs it.

        Args:
            runner: Runner that is about to start training.
        """
        # Touching runner.val_loop builds the ValLoop, which in turn builds its
        # dataloader from the validation dataloader config.
        val_loop = runner.val_loop

        resumed_from = runner.resumed_checkpoint
        if resumed_from is None:
            # Fresh run: there is no checkpoint metadata to inspect.
            return

        # Only the metadata is of interest, so map the checkpoint to CPU.
        # NOTE(review): runner.load_checkpoint may do more than read the file
        # (e.g. load weights into the model) -- confirm this is acceptable here.
        ckpt = runner.load_checkpoint(filename=resumed_from, map_location="cpu")

        # Checkpoints without the flag are treated as "validation not pending".
        if ckpt["meta"].get("needs_validation", False):
            val_loop.run()
and here's the CustomCheckpointHook
from typing import Optional, Union
from mmengine.runner.loops import IterBasedTrainLoop, EpochBasedTrainLoop
from mmengine.hooks import CheckpointHook
from mmengine.registry import HOOKS
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
class CustomCheckpointHook(CheckpointHook):
    """Checkpoint hook that records whether a validation pass is still pending.

    Checkpoints written right before a scheduled validation carry
    ``needs_validation=True`` in their metadata; once validation has finished a
    second checkpoint is written with ``needs_validation=False``.
    """

    def _save_checkpoint(self, runner, maybe_needs_validation=True) -> None:
        """Save a checkpoint, tagging it with "needs_validation" metadata.

        Called from CheckpointHook.after_train_iter() and
        CheckpointHook.after_train_epoch().

        Args:
            runner: The runner of the training process.
            maybe_needs_validation: When False, the checkpoint is always tagged
                ``needs_validation=False``. Pass False when saving a resume
                checkpoint after validation has ended.

        Raises:
            RuntimeError: If ``runner.train_loop`` is neither an
                IterBasedTrainLoop nor an EpochBasedTrainLoop.
        """
        pending_val = False
        if runner.val_loop is not None and maybe_needs_validation:
            loop = runner.train_loop
            # This hook fires from run_iter() / run_epoch() (via 'after_train_iter'
            # / 'after_train_epoch') *before* the loop increments its counter and
            # decides whether to validate. So we apply the loop's own validation
            # test to counter + 1, i.e. the value the loop will see.
            if isinstance(loop, IterBasedTrainLoop):
                upcoming = runner.iter + 1
            elif isinstance(loop, EpochBasedTrainLoop):
                upcoming = runner.epoch + 1
            else:
                raise RuntimeError(
                    "Expected type(runner.train_loop) in [IterBasedTrainLoop, EpochBasedTrainLoop]"
                    f", found {type(loop)}"
                )
            pending_val = (
                upcoming >= runner.val_begin
                and upcoming % runner.val_interval == 0
            )

        # The checkpoint is labelled with the *completed* epoch/iter count,
        # hence the + 1 on whichever counter this hook is keyed on.
        epoch, iteration = runner.epoch, runner.iter
        if self.by_epoch:
            epoch += 1
            step = epoch
        else:
            iteration += 1
            step = iteration
        meta = {"epoch": epoch, "iter": iteration, "needs_validation": pending_val}
        self._save_checkpoint_with_step(runner, step, meta=meta)

    def after_val_epoch(self, runner, metrics):
        """Save the best checkpoint, then a resume checkpoint marked as validated.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics
        """
        if len(metrics) == 0:
            runner.logger.warning(
                "Since `metrics` is an empty dict, the behavior to save "
                "the best checkpoint will be skipped in this evaluation."
            )
            return

        self._save_best_checkpoint(runner, metrics)

        # ValLoop.run() is invoked from the train loop's run() *after*
        # run_epoch() / run_iter() has already incremented epoch / iter. The
        # current counters are therefore already correct, so bypass
        # self._save_checkpoint() (which would add 1 again) and call
        # self._save_checkpoint_with_step() directly.
        epoch, iteration = runner.epoch, runner.iter
        step = epoch if self.by_epoch else iteration
        meta = {"epoch": epoch, "iter": iteration, "needs_validation": False}
        self._save_checkpoint_with_step(runner, step, meta=meta)
and here's the CustomRunner
from typing import Callable, Optional, Union

from mmengine.registry import RUNNERS
from mmengine.runner import Runner
@RUNNERS.register_module()
class CustomRunner(Runner):
    """Runner that remembers which checkpoint file it was resumed from."""

    @property
    def resumed_checkpoint(self) -> Optional[str]:
        """Path passed to :meth:`resume`, or ``None`` if it was never called."""
        # `_resumed_checkpoint` is only assigned inside `resume()`. Fall back to
        # None so that a fresh (non-resumed) run gets None here instead of an
        # AttributeError.
        return getattr(self, "_resumed_checkpoint", None)

    def resume(
        self,
        filename: str,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = "default",
    ) -> None:
        """Record the checkpoint path, then delegate to ``Runner.resume``.

        Args:
            filename: Checkpoint file to resume from; stored so it can be
                inspected later (e.g. by a pre-train validation hook).
            resume_optimizer: Forwarded unchanged to ``Runner.resume``.
            resume_param_scheduler: Forwarded unchanged to ``Runner.resume``.
            map_location: Forwarded unchanged to ``Runner.resume``.
        """
        # Store resumed checkpoint filename so we can check it in our pre-train validation hook
        self._resumed_checkpoint = filename
        return super().resume(
            filename=filename,
            resume_optimizer=resume_optimizer,
            resume_param_scheduler=resume_param_scheduler,
            map_location=map_location,
        )
Please let me know if this is something you would want to turn into a PR, or if you can see a better way of doing this. Thank you.
What is the feature?
Suppose you're training with checkpointing. You save a checkpoint every 10 epochs and run validation every 10 epochs. After the 10th epoch you save a checkpoint before validation, start your validation loop, and the program crashes. Maybe you got preempted on a cluster, maybe you ran out of memory, maybe there was an annotation issue, etc. Then you fix the bug if there was one and restart, but the validation loop gets skipped and training resumes on the 11th epoch.
I don't think this is a good idea, and I think it could be prevented by restructuring how/when the validation loop or check occurs. If validation crashes, when you resume you should automatically re-run validation.
Is this already implemented somewhere/somehow? If not, do you have any recommendations about how to go about it? I've hacked together something similar before in Detectron2 but I'd rather do it the "right way" here.
Any other context?
If it's not already implemented I thought about the following:
results.json
or similar