open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

How to use hook to get the validation loss? #11331

Open Zhong-Zi-Zeng opened 8 months ago

Zhong-Zi-Zeng commented 8 months ago

Recently, I implemented a simple hook to get the validation loss. Here is the code:

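from mmdet.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class MyHook(Hook):
    def __init__(self):
        pass

    def after_train_epoch(self, runner) -> None:
        model = runner.model
        model.eval()

        for i, data in enumerate(runner.val_dataloader):
            # Note: train_step() also back-propagates and updates the
            # weights through optim_wrapper, here on validation data.
            outputs = runner.model.train_step(data, runner.optim_wrapper)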

Although it loads the validation data correctly, I get this error:

Traceback (most recent call last):
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/tools/train.py", line 122, in <module>
    main()
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/tools/train.py", line 118, in main
    runner.train()
  File "/home/miislab-server2/Heng/envs/AOI/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1745, in train
    self.call_hook('before_run')
  File "/home/miislab-server2/Heng/envs/AOI/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1839, in call_hook
    getattr(hook, fn_name)(self, **kwargs)
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/mmdet/engine/hooks/my_hook.py", line 20, in before_run
    outputs = runner.model.train_step(data, runner.optim_wrapper)
  File "/home/miislab-server2/Heng/envs/AOI/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 113, in train_step
    data = self.data_preprocessor(data, True)
  File "/home/miislab-server2/Heng/envs/AOI/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miislab-server2/Heng/envs/AOI/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/mmdet/models/data_preprocessors/data_preprocessor.py", line 140, in forward
    self.pad_gt_masks(data_samples)
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/mmdet/models/data_preprocessors/data_preprocessor.py", line 191, in pad_gt_masks
    data_samples.gt_instances.masks = masks.pad(
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/mmdet/structures/mask/structures.py", line 346, in pad
    return BitmapMasks(padded_masks, *out_shape)
  File "/home/miislab-server2/Heng/Heng_shared/AOI-Project/model_zoo/mmdetection/mmdet/structures/mask/structures.py", line 268, in __init__
    self.masks = np.stack(masks).reshape(-1, height, width)
ValueError: cannot reshape array of size 12582912 into shape (704,1024)
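The numbers in the error are worth a second look: 12582912 = 12 × 1024 × 1024, which is not a multiple of 704 × 1024 = 720896. That is consistent with, for example, 12 GT masks that are still 1024 × 1024 while the batch was padded to 704 × 1024. The NumPy sketch below reproduces that failure. It assumes the mask padding only ever adds pixels (a mask already larger than the target is left unchanged); pad_to and pad_masks are hypothetical stand-ins, not mmdet functions. One plausible cause, not verified here, is a validation pipeline in which LoadAnnotations runs after Resize, so the masks are not resized together with the image.

```python
import numpy as np


def pad_to(mask: np.ndarray, out_shape: tuple) -> np.ndarray:
    """Zero-pad a 2-D mask at the bottom/right up to out_shape.

    Assumption: like the padding in the traceback, this never crops, so a
    mask that is already larger than out_shape comes back unchanged.
    """
    pad_h = max(out_shape[0] - mask.shape[0], 0)
    pad_w = max(out_shape[1] - mask.shape[1], 0)
    return np.pad(mask, ((0, pad_h), (0, pad_w)))


def pad_masks(masks, out_shape):
    """Pad every mask, then stack and reshape as BitmapMasks.__init__ does."""
    padded = [pad_to(m, out_shape) for m in masks]
    return np.stack(padded).reshape(-1, *out_shape)


out_shape = (704, 1024)  # padded batch shape from the error message

# Masks smaller than the batch shape are padded and stack fine.
ok = pad_masks([np.ones((700, 1000), dtype=np.uint8)] * 12, out_shape)
print(ok.shape)  # (12, 704, 1024)

# Masks left at a larger size (here 1024x1024) cannot be reshaped.
big = [np.ones((1024, 1024), dtype=np.uint8)] * 12
try:
    pad_masks(big, out_shape)
except ValueError as err:
    print(err)  # cannot reshape array of size 12582912 ...
```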

I printed the height and width and found that they change every time. However, the code below, which loops over the training dataloader instead, works correctly:

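from mmdet.registry import HOOKS
from mmengine.hooks import Hook


@HOOKS.register_module()
class MyHook(Hook):
    def __init__(self):
        pass

    def after_train_epoch(self, runner) -> None:
        model = runner.model
        model.eval()

        for i, data in enumerate(runner.train_dataloader):
            # Note: train_step() also back-propagates and updates the weights.
            outputs = runner.model.train_step(data, runner.optim_wrapper)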

Could someone help me with this? I would appreciate it.

g824718114 commented 4 months ago

Hey, I had the same problem as you. Here's my solution.

First, define a new hook. It computes the loss on test_dataloader after each training epoch and logs it through after_test_iter in LoggerHook and RuntimeInfoHook, which I redefined (see below).

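import torch
from mmdet.registry import HOOKS
from mmengine.hooks import Hook, LoggerHook, RuntimeInfoHook
from mmengine.model import is_model_wrapper


@HOOKS.register_module()
class MyHook(Hook):

    def val_step(self, model, data, optim_wrapper):
        # Like BaseModel.train_step(), but without backward() or an
        # optimizer update; no_grad() avoids building the autograd graph.
        with torch.no_grad(), optim_wrapper.optim_context(model):
            data = model.data_preprocessor(data, True)
            losses = model(**data, mode='loss')  # type: ignore
        parsed_losses, log_vars = model.parse_losses(losses)
        return log_vars

    def after_train_epoch(self, runner) -> None:
        model = runner.model
        # Unwrap DDP-style wrappers to reach data_preprocessor/parse_losses.
        if is_model_wrapper(model):
            model = model.module
        # Kept in train mode as in the original approach; note that
        # BatchNorm running statistics are then also updated on this data.
        model.train()
        optim_wrapper = runner.optim_wrapper
        dataloader = runner.test_dataloader

        logger = runtimeinfo = None
        for hook in runner.hooks:
            if isinstance(hook, LoggerHook):
                logger = hook
            elif isinstance(hook, RuntimeInfoHook):
                runtimeinfo = hook
        if logger is None or runtimeinfo is None:
            return

        for i, data in enumerate(dataloader):
            outputs = self.val_step(model, data, optim_wrapper)
            # batch_idx is 0-based, as in mmengine's own loops.
            runtimeinfo.after_test_iter(runner, i, data, outputs)
            logger.after_test_iter(runner, i, data, outputs)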
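val_step returns the log_vars built by model.parse_losses, and that dict is what gets passed as outputs to the after_test_iter calls. As far as I know, mmengine's BaseModel.parse_losses takes the mean of each loss tensor, sums lists of per-level losses, and prepends a total "loss" made of every entry whose key contains "loss". The plain-Python sketch below mimics that shape with lists of floats standing in for tensors; parse_losses_sketch and the example loss names are illustrative, not mmengine code.

```python
from collections import OrderedDict
from statistics import mean


def parse_losses_sketch(losses: dict) -> OrderedDict:
    """Plain-Python imitation of the log_vars built by parse_losses.

    Each value is either a list of floats (standing in for one tensor) or a
    list of such lists (standing in for a list of per-level tensors).
    """
    log_vars = []
    for name, value in losses.items():
        if value and isinstance(value[0], list):
            # List of "tensors": sum the mean of each one.
            log_vars.append([name, sum(mean(v) for v in value)])
        else:
            # Single "tensor": take its mean.
            log_vars.append([name, mean(value)])
    # Total loss: sum of every entry whose key contains 'loss'.
    total = sum(v for k, v in log_vars if 'loss' in k)
    log_vars.insert(0, ['loss', total])
    return OrderedDict(log_vars)


losses = {
    'loss_cls': [0.5, 1.0],             # one "tensor"
    'loss_bbox': [[0.25], [0.5, 0.5]],  # per-level "tensors"
    'acc': [90.0],                      # logged, but not summed into 'loss'
}
log_vars = parse_losses_sketch(losses)
print(dict(log_vars))
# {'loss': 1.5, 'loss_cls': 0.75, 'loss_bbox': 0.75, 'acc': 90.0}
```

Each of these keys then becomes a "test/..." scalar once RuntimeInfoHook.after_test_iter stores it.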

Next, I modified two files in the installed mmengine package: /root/anaconda3/lib/python3.11/site-packages/mmengine/hooks/logger_hook.py and /root/anaconda3/lib/python3.11/site-packages/mmengine/hooks/runtime_info_hook.py. In LoggerHook, replace the existing after_test_iter method with this code:

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[dict] = None) -> None:
        """Record logs after each validation-loss iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop
                over ``runner.test_dataloader``.
            data_batch (dict, tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        # Print experiment name every n iterations.
        if self.every_n_train_iters(
                runner, self.interval_exp_name) or (self.end_of_epoch(
                    runner.test_dataloader, batch_idx)):
            exp_info = f'Exp name: {runner.experiment_name}'
            runner.logger.info(exp_info)
        if self.every_n_inner_iters(batch_idx, self.interval):
            tag, log_str = runner.log_processor.get_log_after_iter(
                runner, batch_idx, 'test')
        elif (self.end_of_epoch(runner.test_dataloader, batch_idx)
              and (not self.ignore_last
                   or len(runner.test_dataloader) <= self.interval)):
            # len(runner.test_dataloader) may not be divisible by
            # self.interval. Unless self.ignore_last is True, a final log is
            # still written for the last, incomplete interval (e.g. with 1007
            # batches and interval=1000, at batch 1007).
            tag, log_str = runner.log_processor.get_log_after_iter(
                runner, batch_idx, 'test')
        else:
            return
        runner.logger.info(log_str)
        runner.visualizer.add_scalars(
            tag, step=runner.iter + 1, file_path=self.json_log_path)
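The every_n_inner_iters and end_of_epoch checks above count batch_idx from 0, which is what mmengine's own loops pass to after_*_iter hooks (the index comes straight from enumerate). The two functions below are my plain-Python re-implementation of those helpers, assumed to match mmengine's Hook rather than copied from it, and the ignore_last branch is left out. They show that passing a 1-based index shifts every log one batch early, so the last batch is never logged.

```python
def every_n_inner_iters(batch_idx: int, n: int) -> bool:
    """True on every n-th batch, counting batch_idx from 0."""
    return (batch_idx + 1) % n == 0 if n > 0 else False


def end_of_epoch(num_batches: int, batch_idx: int) -> bool:
    """True only on the last batch of a loader with num_batches batches."""
    return batch_idx + 1 == num_batches


num_batches = 7  # e.g. len(runner.test_dataloader)
interval = 3     # e.g. LoggerHook(interval=3)

# 0-based indices, as produced by enumerate(dataloader).
zero_based = [i for i in range(num_batches)
              if every_n_inner_iters(i, interval)
              or end_of_epoch(num_batches, i)]
# 1-based indices, i.e. passing i + 1 instead of i.
one_based = [i for i in range(num_batches)
             if every_n_inner_iters(i + 1, interval)
             or end_of_epoch(num_batches, i + 1)]

print(zero_based)  # [2, 5, 6]: batches 3, 6 and the last one
print(one_based)   # [1, 4, 5]: every log shifted one batch early
```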

In RuntimeInfoHook, add this method:

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the loop
                over ``runner.test_dataloader``.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is not None:
            for key, value in outputs.items():
                # Prefixed with 'test/' to match the 'test' mode used in
                # LoggerHook.after_test_iter.
                runner.message_hub.update_scalar(f'test/{key}', value)
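The "test/" prefix is what ties the two hooks together: RuntimeInfoHook stores each value under test/<name>, and the LogProcessor call in LoggerHook (mode 'test') collects the scalars for that mode. My understanding, not verified line by line, is that LogProcessor strips the prefix and reports loss keys as a mean over its window_size most recent values, so the printed validation loss is a moving average rather than a per-batch value. ToyMessageHub and collect_scalars below are hypothetical stand-ins that only mimic that bookkeeping.

```python
from collections import defaultdict
from statistics import mean


class ToyMessageHub:
    """Hypothetical stand-in for MessageHub scalar bookkeeping."""

    def __init__(self):
        self.log_scalars = defaultdict(list)

    def update_scalar(self, key: str, value: float) -> None:
        self.log_scalars[key].append(float(value))


def collect_scalars(hub: ToyMessageHub, mode: str,
                    window_size: int = 10) -> dict:
    """Gather '<mode>/...' scalars; average loss keys over the last window."""
    tag = {}
    for prefixed_key, history in hub.log_scalars.items():
        if not prefixed_key.startswith(f'{mode}/'):
            continue
        key = prefixed_key.partition('/')[-1]
        if 'loss' in key:
            tag[key] = mean(history[-window_size:])
        else:
            tag[key] = history[-1]
    return tag


hub = ToyMessageHub()
hub.update_scalar('train/loss', 3.0)
# What RuntimeInfoHook.after_test_iter above does for each batch:
for outputs in ({'loss': 1.0, 'acc': 80.0}, {'loss': 2.0, 'acc': 90.0}):
    for key, value in outputs.items():
        hub.update_scalar(f'test/{key}', value)

print(collect_scalars(hub, 'test'))  # {'loss': 1.5, 'acc': 90.0}
```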

Then, register the hook and change test_dataloader in the config file:

custom_hooks = [
    ...
    dict(type='MyHook'),
    ...
]
test_dataloader = dict(
    ...
    collate_fn=dict(type='yolov5_collate'),
    dataset=dict(
        ...
        pipeline=[
            ...
            dict(type='LoadAnnotations', with_bbox=True),
            ...
        ]),
    ...
)
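If the hook lives in its own module that nothing else imports (for example the mmdet/engine/hooks/my_hook.py file that appears in the traceback in the question), MMEngine's custom_imports config field can import it so that @HOOKS.register_module() runs before the runner builds custom_hooks. A minimal fragment, assuming that module path; adjust it to wherever your hook actually lives:

```python
# Import the module that defines MyHook so its @HOOKS.register_module()
# decorator runs. 'mmdet.engine.hooks.my_hook' is the path from the
# traceback in the question; replace it with your own module.
custom_imports = dict(
    imports=['mmdet.engine.hooks.my_hook'],
    allow_failed_imports=False)

# Note the spelling: MMEngine reads the key `custom_hooks`.
custom_hooks = [
    dict(type='MyHook'),
]
```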

Finally, run train.py. The validation loss appears in the training output and is also saved in the JSON log. (Two screenshots were attached here.)
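To pull the values back out of the JSON log afterwards, a small reader is enough, assuming (as far as I know) that MMEngine writes its scalar logs as one JSON object per line. The key to look for ('loss' here) is an assumption: check a line of your own log first, and note that training records may use the same key, so you may need another field to tell them apart. read_scalar is a hypothetical helper.

```python
import json
import os
import tempfile
from typing import List, Tuple


def read_scalar(json_log_path: str, key: str = 'loss') -> List[Tuple[int, float]]:
    """Return (step, value) pairs for `key` from a JSON-lines log file.

    Blank lines, lines that are not valid JSON, and records without `key`
    are skipped. A missing 'step' field is reported as -1.
    """
    points = []
    with open(json_log_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if key in record:
                points.append((record.get('step', -1), float(record[key])))
    return points


# Demo on a throwaway file with two records.
demo = tempfile.NamedTemporaryFile('w', suffix='.json', delete=False)
demo.write('{"loss": 1.25, "step": 10}\n{"lr": 0.01, "step": 10}\n')
demo.close()
print(read_scalar(demo.name))  # [(10, 1.25)]
os.remove(demo.name)
```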

radish512 commented 3 months ago

@g824718114 Thank you for providing the method. May I ask about the change to test_dataloader in the config file: where is collate_fn=dict(type='yolov5_collate') defined? I haven't used YOLOv5. Or is this a universal modification?

g824718114 commented 3 months ago

Sorry, I didn't make that clear. This is not a universal modification: the collate_fn in test_dataloader needs to be the same as the one in train_dataloader.
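A plain-Python sketch of that rule, using placeholder dicts rather than a working config: copy the training dataloader's collate_fn when it sets one, and otherwise leave it unset in the test dataloader too, so both fall back to the same default. (To my knowledge yolov5_collate is registered by MMYOLO, not MMDetection, which would explain why it is unknown outside a YOLOv5 setup.)

```python
import copy

# Placeholder dataloader configs; a real config defines many more fields.
train_dataloader = dict(
    batch_size=8,
    collate_fn=dict(type='yolov5_collate'),
    dataset=dict(type='CocoDataset'),
)
test_dataloader = dict(
    batch_size=1,
    dataset=dict(type='CocoDataset', test_mode=True),
)

# Mirror the training collate_fn; if the training loader sets none,
# make sure the test loader does not set one either.
if 'collate_fn' in train_dataloader:
    test_dataloader['collate_fn'] = copy.deepcopy(
        train_dataloader['collate_fn'])
else:
    test_dataloader.pop('collate_fn', None)

print(test_dataloader['collate_fn'])  # {'type': 'yolov5_collate'}
```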

---Original--- Date: Tue, Jun 11, 2024 09:42 AM. Subject: Re: [open-mmlab/mmdetection] How to use hook to get the validation loss? (Issue #11331)