facebookresearch / unbiased-teacher

PyTorch code for ICLR 2021 paper Unbiased Teacher for Semi-Supervised Object Detection
https://arxiv.org/abs/2102.09480
MIT License

a question about the code: `TEST.VAL_LOSS` #12

Closed: xiaohu2015 closed this issue 3 years ago

xiaohu2015 commented 3 years ago

I found that the trainer evaluates both the student and the teacher model every EVAL_PERIOD iterations:

    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_student))
    ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_teacher))

but why do you add another pair of eval hooks?

        if cfg.TEST.VAL_LOSS:  # default is True # save training time if not applied
            ret.append(
                LossEvalHook(
                    cfg.TEST.EVAL_PERIOD,
                    self.model,
                    build_detection_test_loader(
                        self.cfg,
                        self.cfg.DATASETS.TEST[0],
                        DatasetMapper(self.cfg, True),
                    ),
                    model_output="loss_proposal",
                    model_name="student",
                )
            )

            ret.append(
                LossEvalHook(
                    cfg.TEST.EVAL_PERIOD,
                    self.model_teacher,
                    build_detection_test_loader(
                        self.cfg,
                        self.cfg.DATASETS.TEST[0],
                        DatasetMapper(self.cfg, True),
                    ),
                    model_output="loss_proposal",
                    model_name="",
                )
            )
vinkle commented 3 years ago

The first two hooks calculate the AP on the COCO validation set for the student and teacher models. The hooks inside the `cfg.TEST.VAL_LOSS` condition calculate the losses on the COCO validation set (RPN classification/regression losses and box-head classification/regression losses).
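
For readers unfamiliar with this pattern, a minimal sketch of such a loss hook is shown below. It assumes detectron2's `HookBase` and `EventStorage` APIs; the class name, attribute names, and the `prefix` argument are illustrative and not the repository's exact `LossEvalHook` implementation:

    import torch
    from detectron2.engine import HookBase


    class ValidationLossHook(HookBase):
        """Periodically run a model on a validation loader and log its losses."""

        def __init__(self, eval_period, model, data_loader, prefix="val"):
            self._period = eval_period
            self._model = model
            self._loader = data_loader
            self._prefix = prefix

        def after_step(self):
            next_iter = self.trainer.iter + 1
            if self._period > 0 and next_iter % self._period == 0:
                self._do_loss_eval()

        def _do_loss_eval(self):
            # Keep the model in training mode so it returns a loss dict,
            # but disable gradients; this extra forward pass over the
            # validation set is where the additional cost comes from.
            was_training = self._model.training
            self._model.train()
            totals, num_batches = {}, 0
            with torch.no_grad():
                for inputs in self._loader:
                    loss_dict = self._model(inputs)
                    for name, value in loss_dict.items():
                        totals[name] = totals.get(name, 0.0) + float(value)
                    num_batches += 1
            if num_batches > 0:
                self.trainer.storage.put_scalars(
                    **{f"{self._prefix}_{k}": v / num_batches for k, v in totals.items()}
                )
            self._model.train(was_training)

Registering two such hooks, one for the student (`self.model`) and one for the teacher (`self.model_teacher`), corresponds to the two `LossEvalHook` entries quoted above.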

ycliu93 commented 3 years ago

Yes, it is as @vinkle said (thanks @vinkle!), though it is a hacky way to compute validation losses. It requires additional forward passes through the model, so it is not very efficient.

You can set `cfg.TEST.VAL_LOSS = False` to speed up training if you do not need the validation losses.
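
For example, a minimal sketch, assuming `cfg` is the yacs/detectron2 `CfgNode` that the training script builds, modified before `cfg.freeze()` is called:

    cfg.TEST.VAL_LOSS = False                       # direct assignment
    # or, in the command-line override style that detectron2 configs support:
    cfg.merge_from_list(["TEST.VAL_LOSS", False])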