open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0

add _iter and _max_iter as arguments to runner #361

Open · baibaidj opened 3 years ago

baibaidj commented 3 years ago

Describe the feature

Add `self._iter` and `self._max_iters` as arguments to `self.model.train_step()` in the runner, perhaps like this:

```python
# A runner method; the two kwargs lines below are the proposed addition.
def run_iter(self, data_batch, train_mode, **kwargs):
    # Expose the global training progress to the model.
    kwargs['iter'] = self._iter
    kwargs['max_iter'] = self._max_iters
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        # model.train_step() must accept these kwargs and pass them on.
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
```
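To make the proposed data flow concrete, here is a minimal, framework-free sketch. `ToyRunner` and `ToyModel` are hypothetical stand-ins, not mmcv or mmseg classes; they only mimic the `run_iter` logic above so the kwargs path can be exercised without installing anything. A real segmentor would still have to forward `iter` and `max_iter` from `train_step` down to its decode head.

```python
class ToyModel:
    """Hypothetical model whose train_step accepts the injected kwargs."""

    def __init__(self):
        self.seen = []  # (iter, max_iter) received on each step

    def train_step(self, data_batch, optimizer, **kwargs):
        # A real segmentor would pass these on to the decode head.
        self.seen.append((kwargs.get('iter'), kwargs.get('max_iter')))
        return {'loss': 0.0, 'log_vars': {}, 'num_samples': len(data_batch)}

    def val_step(self, data_batch, optimizer, **kwargs):
        return self.train_step(data_batch, optimizer, **kwargs)


class ToyRunner:
    """Hypothetical runner mimicking the proposed run_iter."""

    def __init__(self, model, max_iters):
        self.model = model
        self.optimizer = None
        self._iter = 0
        self._max_iters = max_iters
        self.outputs = None

    def run_iter(self, data_batch, train_mode, **kwargs):
        # Inject the global progress counters, as proposed above.
        kwargs['iter'] = self._iter
        kwargs['max_iter'] = self._max_iters
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('train_step() and val_step() must return a dict')
        self.outputs = outputs

    def train(self, data_batches):
        for batch in data_batches:
            self.run_iter(batch, train_mode=True)
            self._iter += 1  # advance the global counter after each step


model = ToyModel()
runner = ToyRunner(model, max_iters=3)
runner.train([[1, 2], [3, 4], [5, 6]])
print(model.seen)  # [(0, 3), (1, 3), (2, 3)]
```

The model never has to count anything itself: the runner is the single source of truth for progress.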

Motivation

The global iteration and the maximum number of iterations may be needed in the network's forward pass during training and/or validation. Recently I came across at least two studies that require this: Temporally Distributed Networks for Fast Video Semantic Segmentation (CVPR'20) and Regularizing Deep Networks with Semantic Data Augmentation (PAMI'20).

Related resources

The following two links show how these two values are used during training:

https://github.com/feinanshan/TDNet/blob/3f8b5378fcc7f97c26b3760ddaf3d4402cf477d1/Training/train.py#L118
https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/318c30976d0c412a7dd10250b0164beac6d4fbeb/Semantic%20segmentation%20on%20Cityscapes/train_isda.py#L363
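For ISDA, the dependence on progress is a linear ramp of the augmentation strength, the same one the head below computes as `ratio` in `forward_train`:

$$
\lambda_t = \lambda_0 \cdot \frac{t}{T}, \qquad 0 \le t \le T
$$

where $t$ is the global iteration, $T$ the maximum number of iterations, and $\lambda_0$ corresponds to `isda_lambda`. For example, with $\lambda_0 = 2.5$ and $T = 4 \times 10^5$, halfway through training ($t = 2 \times 10^5$) the ratio is $2.5 \times 0.5 = 1.25$. Without access to $t$ and $T$, the head cannot compute this schedule.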

Additional context

I was able to implement implicit semantic data augmentation (ISDA) in mmseg and designed a workaround that keeps both values inside the decode head, as follows: each time `self.forward_train` is called with ISDA enabled, `self._iter` is incremented by one.

```python
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

# ISDALoss comes from the ISDA reference code linked above; this import
# path is an assumption, adjust it to wherever that file is vendored.
from isda import ISDALoss


class FCNHead(BaseDecodeHead):
    """FCN head with optional implicit semantic data augmentation (ISDA).

    Args:
        is_use_isda (bool): Whether to use implicit semantic data augmentation.
        isda_lambda (float): The hyper-parameter lambda_0 for ISDA, selected
            from {1, 2.5, 5, 7.5, 10}.
        start_iters (int): Initial value of the head-local iteration counter.
        max_iters (int): Total number of training iterations.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 is_use_isda=False,
                 isda_lambda=2.5,
                 start_iters=1,
                 max_iters=400000,
                 **kwargs):
        assert isinstance(num_convs, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super().__init__(**kwargs)
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for _ in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

        self.is_use_isda = is_use_isda
        self.isda_lambda = isda_lambda
        # Head-local progress counter, advanced once per forward_train call.
        self._iter = start_iters
        self._max_iters = max_iters
        if is_use_isda:
            # The feature map fed to conv_seg has self.channels channels.
            self.isda_augmentor = ISDALoss(self.channels, self.num_classes)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        feat_map = self.convs(x) if self.num_convs > 0 else x
        if self.concat_input:
            feat_map = self.conv_cat(torch.cat([x, feat_map], dim=1))
        output = self.cls_seg(feat_map)
        if self.is_use_isda and self.training:
            # ISDA also needs the features that produced the logits.
            return output, feat_map.detach()
        return output

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        if self.is_use_isda:
            # lambda_0 scaled by training progress (self._iter / self._max_iters)
            ratio = self.isda_lambda * self._iter / self._max_iters
            seg_logits_1, last_feat_map = self.forward(inputs)
            seg_logits = self.isda_augmentor(last_feat_map, self.conv_seg,
                                             seg_logits_1, gt_semantic_seg,
                                             ratio)
            self._iter += 1
        else:
            seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses
```
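One limitation of this workaround can be shown in plain Python. `_iter` is an ordinary attribute, not a registered buffer, so it is not part of the model's `state_dict`. When training is resumed from a checkpoint, the rebuilt head starts counting from `start_iters` again, while a runner that restores its iteration from the checkpoint continues where it stopped. `HeadCounter` and `runner_ratio` below are hypothetical stand-ins for the head's counter logic, not mmseg code.

```python
class HeadCounter:
    """Hypothetical stand-in for the counter logic inside the decode head."""

    def __init__(self, isda_lambda=2.5, start_iters=1, max_iters=400000):
        self.isda_lambda = isda_lambda
        self._iter = start_iters  # plain attribute, so not in state_dict
        self._max_iters = max_iters

    def step(self):
        """Mimic one forward_train call: compute the ratio, then advance."""
        ratio = self.isda_lambda * self._iter / self._max_iters
        self._iter += 1
        return ratio


def runner_ratio(isda_lambda, cur_iter, max_iters):
    """Same ramp, but driven by the runner's (checkpointed) iteration."""
    return isda_lambda * cur_iter / max_iters


# start_iters=0 so the head counter lines up with a 0-based runner counter.
head = HeadCounter(isda_lambda=2.0, start_iters=0, max_iters=10)
before = [head.step() for _ in range(5)]  # iterations 0..4 of the first run

# Resuming rebuilds the model: the runner restores iteration 5, the head does not.
resumed_head = HeadCounter(isda_lambda=2.0, start_iters=0, max_iters=10)
after_resume = resumed_head.step()   # ramp restarts at 0.0
expected = runner_ratio(2.0, 5, 10)  # iteration 5 should use 1.0
print(before, after_resume, expected)
```

Reading the counter from the runner instead avoids this drift, which is what the feature request asks for.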

This is not an ideal solution, though. I plan to implement TDNet in mmseg in the future and will probably run into the same issue.

xvjiarui commented 3 years ago

Hi @baibaidj, thanks for the suggestion. Unfortunately, the training flow and the network are currently kept separate. We may consider this and do some refactoring in the future.
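Given that separation, one possible bridge that needs no runner change is a hook that copies the runner's counters into the head before every training iteration. In mmcv 1.x such a hook would subclass `mmcv.runner.Hook`, override `before_train_iter(self, runner)`, read `runner.iter` and `runner.max_iters`, and unwrap the model with `mmcv.parallel.is_module_wrapper`. The sketch below only mimics that pattern with hypothetical stand-in classes (`Hook`, `SyncIterHook`, `FakeRunner`, `Wrapper`, `Segmentor`, `Head`) so it runs without mmcv; it is an untested idea, not an mmseg feature, and it simplifies the wrapper check to `hasattr(model, 'module')`.

```python
class Hook:
    """Minimal stand-in for an mmcv-style hook base class."""

    def before_train_iter(self, runner):
        pass


class SyncIterHook(Hook):
    """Hypothetical hook: push the runner's progress into the decode head."""

    def before_train_iter(self, runner):
        model = runner.model
        # Unwrap data-parallel style wrappers that expose `.module`.
        if hasattr(model, 'module'):
            model = model.module
        model.decode_head._iter = runner.iter
        model.decode_head._max_iters = runner.max_iters


class Head:
    def __init__(self):
        self._iter = None
        self._max_iters = None


class Segmentor:
    def __init__(self):
        self.decode_head = Head()


class Wrapper:
    """Stand-in for a data-parallel wrapper."""

    def __init__(self, module):
        self.module = module


class FakeRunner:
    def __init__(self, model, max_iters, hooks):
        self.model = model
        self.iter = 0
        self.max_iters = max_iters
        self.hooks = hooks

    def train(self):
        seen = []
        while self.iter < self.max_iters:
            for hook in self.hooks:
                hook.before_train_iter(self)
            head = self.model.module.decode_head
            # What forward_train would read on this iteration.
            seen.append((head._iter, head._max_iters))
            self.iter += 1
        return seen


runner = FakeRunner(Wrapper(Segmentor()), max_iters=3, hooks=[SyncIterHook()])
seen = runner.train()
print(seen)  # [(0, 3), (1, 3), (2, 3)]
```

With such a hook, the head would only read `self._iter` rather than increment it, so its value follows the runner, including after a resume.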

jianlong-yuan commented 2 years ago

Any news?