facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

How do I compute validation loss during training? #810

Closed tshead2 closed 4 years ago

tshead2 commented 4 years ago

How do I compute validation loss during training?

I'm trying to compute the loss on a validation dataset for each iteration during training. To do so, I've created my own hook:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)

    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")

... which I register with a DefaultTrainer. The hook code is called during training, but fails with the following:

INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-63b308743b7d>", line 8, in after_step
    loss = self.trainer.model(batch)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 123, in forward
    proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 164, in forward
    losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 322, in losses
    gt_objectness_logits, gt_anchor_deltas = self._get_ground_truth()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn_outputs.py", line 262, in _get_ground_truth
    for image_size_i, anchors_i, gt_boxes_i in zip(self.image_sizes, anchors, self.gt_boxes):
TypeError: zip argument #3 must support iteration
INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)

The traceback seems to imply that ground truth data is missing, which made me think that the data loader was the problem. However, switching to a training loader produces a different error:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_train_loader(config, dataset_name)

    def after_step(self):
        for batch in self._loader:
            loss = self.trainer.model(batch)
            log.debug(f"validation loss: {loss}")

INFO:detectron2.engine.train_loop:Starting training from iteration 0
ERROR:detectron2.engine.train_loop:Exception during training:
Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 133, in train
    self.after_step()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/engine/train_loop.py", line 153, in after_step
    h.after_step()
  File "<ipython-input-6-e0d2c509cc72>", line 7, in after_step
    for batch in self._loader:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 109, in __iter__
    for d in self.dataset:
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/data/common.py", line 39, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/ascldap/users/tshead/miniconda3/lib/python3.7/site-packages/detectron2/utils/serialize.py", line 23, in __call__
    return self._obj(*args, **kwargs)
TypeError: 'str' object is not callable

INFO:detectron2.engine.hooks:Total training time: 0:00:00 (0:00:00 on hooks)

As a sanity check, inference works just fine:

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super(ValidationLoss, self).__init__()
        self._loader = detectron2.data.build_detection_test_loader(config, dataset_name)

    def after_step(self):
        for batch in self._loader:
            with detectron2.evaluation.inference_context(self.trainer.model):
                loss = self.trainer.model(batch)
                log.debug(f"validation loss: {loss}")

INFO:detectron2.engine.train_loop:Starting training from iteration 0
DEBUG:root:validation loss: [{'instances': Instances(num_instances=100, image_height=720, image_width=720, fields=[pred_boxes = Boxes(tensor([[4.4867e+02, 1.9488e+02, 5.1496e+02, 3.9878e+02],
        [4.2163e+02, 1.1204e+02, 6.1118e+02, 5.5378e+02],
        [8.7323e-01, 3.0374e+02, 9.2917e+01, 3.8698e+02],
        [4.3202e+02, 2.0296e+02, 5.7938e+02, 3.6817e+02],
        ...

... but that isn't what I want, of course. Any thoughts?

Thanks in advance, Tim
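The answers further down converge on two ingredients: keep the model in training mode (so it returns a dict of losses) but wrap the forward pass in torch.no_grad(), and feed it batches that still carry ground-truth annotations. The test loader strips annotations because it maps data with DatasetMapper(cfg, is_train=False), which is why the RPN ends up with no gt_boxes to iterate over. A minimal sketch of that idea, not any commenter's exact code:

import logging

import torch
import detectron2.data
import detectron2.engine
from detectron2.data import DatasetMapper

log = logging.getLogger(__name__)

class ValidationLoss(detectron2.engine.HookBase):
    def __init__(self, config, dataset_name):
        super().__init__()
        # A training-mode mapper keeps the annotations while still iterating
        # the validation dataset exactly once per pass.
        self._loader = detectron2.data.build_detection_test_loader(
            config, dataset_name, mapper=DatasetMapper(config, is_train=True)
        )

    def after_step(self):
        with torch.no_grad():  # train mode for losses, but no gradient tracking
            for batch in self._loader:
                loss_dict = self.trainer.model(batch)
                log.debug(f"validation loss: {loss_dict}")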

edoardounali commented 1 year ago

I extended the code above to log both the train and val loss on the same graph in TensorBoard. I'm putting it here because I think it could be useful for others ending up here.

This is what your TB log will eventually look like: [image]
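For context, the hook referred to as "the code above" is not reproduced in this excerpt. It is roughly the following pattern (a sketch, not the original author's exact code), and it is what puts the val_-prefixed scalars into the event storage that the writer below looks for:

import torch
import detectron2.utils.comm as comm
from detectron2.data import build_detection_train_loader
from detectron2.engine import HookBase

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()
        # reuse the training pipeline (which keeps annotations) on the val set
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.TEST
        self.cfg.freeze()
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            loss_dict = self.trainer.model(data)
            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict
            # prefix with "val_" so the custom writer can route these scalars
            loss_dict_reduced = {"val_" + k: v.item()
                                 for k, v in comm.reduce_dict(loss_dict).items()}
            if comm.is_main_process():
                self.trainer.storage.put_scalars(
                    total_val_loss=sum(loss_dict_reduced.values()),
                    **loss_dict_reduced)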

To do this, first create a custom tensorboard writer:

import os
from torch.utils.tensorboard import SummaryWriter
from detectron2.utils.events import EventWriter, get_event_storage

class CustomTensorboardXWriter(EventWriter):
    """
    Writes scalars and images based on storage key to train or val tensorboard file.
    """

    def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
        """
        Args:
            log_dir (str): the base directory to save the output events. This class creates two subdirs in log_dir
            window_size (int): the scalars will be median-smoothed by this window size

            kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
        """
        self._window_size = window_size

        # separate the writers into a train and a val writer
        train_writer_path = os.path.join(log_dir,"train")
        os.makedirs(train_writer_path, exist_ok=True)
        self._writer_train = SummaryWriter(train_writer_path, **kwargs)

        val_writer_path = os.path.join(log_dir,"val")
        os.makedirs(val_writer_path, exist_ok=True)
        self._writer_val = SummaryWriter(val_writer_path, **kwargs)

    def write(self):

        storage = get_event_storage()
        for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
            if k.startswith("val_"):
                k = k.replace("val_","")
                self._writer_val.add_scalar(k, v, iter)
            else:
                self._writer_train.add_scalar(k, v, iter)

        if len(storage._vis_data) >= 1:
            for img_name, img, step_num in storage._vis_data:
                if img_name.startswith("val_"):
                    self._writer_val.add_image(img_name.replace("val_", "", 1), img, step_num)
                else:
                    self._writer_train.add_image(img_name, img, step_num)
            # Storage stores all image data and rely on this writer to clear them.
            # As a result it assumes only one writer will use its image data.
            # An alternative design is to let storage store limited recent
            # data (e.g. only the most recent image) that all writers can access.
            # In that case a writer may not see all image data if its period is long.
            storage.clear_images()

        if len(storage._histograms) >= 1:
            for params in storage._histograms:
                self._writer_train.add_histogram_raw(**params)
            storage.clear_histograms()

    def close(self):
        if hasattr(self, "_writer_train"):  # doesn't exist when the code fails at import
            self._writer_train.close()
            self._writer_val.close()

Then register this writer in your trainer. It will plot the train and val metrics on the same graph.

import os

from detectron2.engine import DefaultTrainer
from detectron2.evaluation import COCOEvaluator
from detectron2.utils.events import CommonMetricPrinter, JSONWriter

class Trainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR,"inference")
        return COCOEvaluator(dataset_name, cfg, True, output_folder)

    def build_writers(self):
        """
        Overwrites the default writers to contain our custom tensorboard writer

        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        """
        return [
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            CustomTensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]
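A usage sketch for tying the pieces together (Trainer is the class above, and ValidationLoss is assumed to be a hook like the sketches earlier in this thread; this is not verbatim from the thread). The last two hooks are swapped so the val_* scalars land in storage before the PeriodicWriter flushes them:

from detectron2.config import get_cfg

cfg = get_cfg()
# ... merge your model, dataset, and solver settings into cfg here ...

trainer = Trainer(cfg)
trainer.register_hooks([ValidationLoss(cfg)])
# swap the last two hooks so ValidationLoss runs before PeriodicWriter
trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
trainer.resume_or_load(resume=False)
trainer.train()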

Hi, I followed every step, but the writer saves the val loss only every period defined by

cfg.TEST.EVAL_PERIOD = 50

and not every 20 steps (window_size).

Furthermore, the TensorBoard scalars are plotted on two different graphs, not together under total_loss.

ravijo commented 1 year ago

@edoardounali

The TensorBoard scalars are plotted on two different graphs, not together under total_loss.

For this issue, you may check the following working demo: detectron2_tutorial

edoardounali commented 1 year ago

@edoardounali

The TensorBoard scalars are plotted on two different graphs, not together under total_loss.

For this issue, you may check the following working demo: detectron2_tutorial

Solved! Thank you so much!

CA4GitHub commented 9 months ago

@ravijo is there a reason your ValLossHook class's after_step method doesn't loop over the batches returned by the data loader (i.e., self._loader)? I expected the method to loop over all the batches and compute an average or total loss.
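For what it's worth, looping over the whole loader and averaging is straightforward; here is a sketch (full_validation_loss is a made-up helper name, not something from ravijo's demo):

import torch

def full_validation_loss(model, val_loader):
    # Average the summed loss over every batch; assumes the model is in
    # training mode so that it returns a dict of losses.
    total, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            loss_dict = model(batch)
            total += sum(v.item() for v in loss_dict.values())
            num_batches += 1
    return total / max(num_batches, 1)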

CHN-001 commented 4 months ago

Question 1:

  File "/home/server/anaconda3/envs/bcnet/lib/python3.7/site-packages/yacs/config.py", line 147, in __setattr__
    name, value
AttributeError: Attempted to set TRAIN to ('coco_my_val',), but CfgNode is immutable

Solution: I solved it by modifying the code:

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()  # unfreeze the config to allow modifications
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self.cfg.freeze()  # refreeze the config after modifications
        self._loader = iter(build_detection_train_loader(self.cfg))

Question 2:

  File "/home/server/anaconda3/envs/bcnet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 2 required positional arguments: 'c_iter' and 'max_iter'

Solution: I solved it by modifying the code:

class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.clone()
        self.cfg.defrost()  # unfreeze the config to allow modifications
        self.cfg.DATASETS.TRAIN = cfg.DATASETS.VAL
        self.cfg.freeze()  # refreeze the config after modifications
        self._loader = iter(build_detection_train_loader(self.cfg))

    def after_step(self):
        data = next(self._loader)
        with torch.no_grad():
            # get the current iteration and the maximum number of iterations
            c_iter = self.trainer.iter
            max_iter = self.trainer.max_iter

            # pass c_iter and max_iter when calling the model
            loss_dict = self.trainer.model(data, c_iter=c_iter, max_iter=max_iter)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {"val_" + k: v.item() for k, v in
                                 comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                self.trainer.storage.put_scalars(total_val_loss=losses_reduced,
                                                 **loss_dict_reduced)

ahmadsadeed commented 1 month ago

Hi Alono, nice that you answered; I was about to look into wesleylp's question. I have only tried it on a single GPU, so I have no idea what changes multiple GPUs would imply, sorry.

Hi, I tried your code, but after running validation it just hangs and does not run anything else. Please help me. Thank you very much. After a while, an error popped up: RuntimeError: [/opt/conda/conda-bld/pytorch_1587428207430/work/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:136] Timed out waiting 1800000ms for send operation to complete

Did you find the solution to this issue?

Fix to the issue:

def build_hooks(self):
    hooks = super().build_hooks()
    hooks.insert(-1, LossEvalHook(
        self.cfg.TEST.EVAL_PERIOD,
        self.model,
        build_detection_test_loader(
            self.cfg,
            self.cfg.DATASETS.TEST[0],
            DatasetMapper(self.cfg, True)
        )
    ))
    # swap the order of PeriodicWriter and LossEvalHook;
    # the code hangs when the number of GPUs > 1 if this line is removed
    hooks = hooks[:-2] + hooks[-2:][::-1]
    return hooks

Why not use hooks.append() instead, so that the swap wouldn't be needed?
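For reference, the LossEvalHook used in the build_hooks snippet above is not defined anywhere in this thread. A hedged sketch of the commonly used version, which averages the loss over the whole test loader every eval_period iterations and logs it as validation_loss, looks roughly like this (not necessarily the exact class the commenter used):

import numpy as np
import torch
import detectron2.utils.comm as comm
from detectron2.engine import HookBase

class LossEvalHook(HookBase):
    def __init__(self, eval_period, model, data_loader):
        self._period = eval_period
        self._model = model
        self._data_loader = data_loader

    def _do_loss_eval(self):
        # model is still in training mode here, so it returns a dict of losses
        losses = []
        with torch.no_grad():
            for inputs in self._data_loader:
                loss_dict = self._model(inputs)
                losses.append(sum(v.item() for v in loss_dict.values()))
        mean_loss = np.mean(losses)
        if comm.is_main_process():
            self.trainer.storage.put_scalar("validation_loss", mean_loss)
        comm.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            self._do_loss_eval()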