dywu98 / CBL-Conditional-Boundary-Loss

The official implementation of an IEEE TIP paper under review.
12 stars · 3 forks

Could you provide config files based on mmseg? #1

Open euphonium1998 opened 9 months ago

euphonium1998 commented 9 months ago

Could you provide config files based on mmseg? I have set everything up following the README instructions, but mmseg still will not run.

dywu98 commented 9 months ago

What exactly is the problem? Please provide the detailed error message for reference.

dywu98 commented 9 months ago

Since there is no error message to go on, my guess is that the training pipeline does not generate the boundary GT, so the decoder is missing the boundary GT it needs to compute the CBL loss.

If that is indeed the problem, try changing the training pipeline in the config to the following (an example for the ADE20K dataset):

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='GenerateBoundary', dilation=0.02),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg'])
]

Please add dict(type='GenerateBoundary', dilation=0.02), between LoadAnnotations and Resize.

If you have any other questions, feel free to email dongyue_wu@hust.edu.cn.

euphonium1998 commented 9 months ago

Hi, here is the last part of the error message:

Traceback (most recent call last):
  File "tools/train.py", line 242, in <module>
    main()
  File "tools/train.py", line 231, in main
    train_segmentor(
  File "/data/cgc-segmentation/mmsegv030/mmsegmentation/mmseg/apis/train.py", line 194, in train_segmentor
    runner.run(data_loaders, cfg.workflow)
  File "/home/muxiangyu/miniconda3/envs/mmseg-old/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 144, in run
    iter_runner(iter_loaders[i], **kwargs)
  File "/home/muxiangyu/miniconda3/envs/mmseg-old/lib/python3.8/site-packages/mmcv/runner/iter_based_runner.py", line 64, in train
    outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
  File "/home/muxiangyu/miniconda3/envs/mmseg-old/lib/python3.8/site-packages/mmcv/parallel/data_parallel.py", line 77, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/data/cgc-segmentation/mmsegv030/mmsegmentation/mmseg/models/segmentors/base.py", line 138, in train_step
    losses = self(**data_batch)
  File "/home/muxiangyu/miniconda3/envs/mmseg-old/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/muxiangyu/miniconda3/envs/mmseg-old/lib/python3.8/site-packages/mmcv/runner/fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "/data/cgc-segmentation/mmsegv030/mmsegmentation/mmseg/models/segmentors/base.py", line 108, in forward
    return self.forward_train(img, img_metas, **kwargs)
TypeError: forward_train() got an unexpected keyword argument 'gt_boundary_seg'

And here is my pipeline:

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='GenerateBoundary', dilation=0.02),
    dict(type='Resize', img_scale=(1024, 768), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.0),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 768),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip', prob=0.0),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

Since you did not provide a demo config, I adapted the OCRNet config; the mmseg version I am using is v0.30.0.

# model settings
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='CascadeEncoderDecoder',
    num_stages=2,
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=[
        dict(
            type='FCNHead',
            in_channels=1024,
            in_index=2,
            channels=256,
            num_convs=1,
            concat_input=False,
            dropout_ratio=0.1,
            num_classes=6,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        dict(
            type='New_ER5OCRHead',
            in_channels=2048,
            in_index=3,
            channels=512,
            ocr_channels=256,
            dropout_ratio=0.1,
            num_classes=6,
            norm_cfg=norm_cfg,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    ],
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

Thank you very much for your help! Could you provide configs for some classic segmentation networks so that readers can reproduce the results more easily?

dywu98 commented 9 months ago

Apologies for the slow reply~

While trying to improve the instructions for running this work with MMSegmentation, I found that mmsegmentation has gone through many updates, and the current instructions no longer work with the latest mmsegmentation. So if you still need the CBL + mmsegmentation code for the earlier version, please contact me by email and I can send you the (not yet systematically cleaned up) CBL + mmsegmentation repo. We plan to release CBL code based on mmsegmentation 1.2 or later in the future; stay tuned!

Also, here is the complete config you asked for, training OCRNet on Cityscapes:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='CascadeEncoderDecoder',
    num_stages=2,
    pretrained='open-mmlab://msra/hrnetv2_w48',
    backbone=dict(
        type='HRNet',
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384)))),
    decode_head=[
        dict(
            type='FCNHead',
            in_channels=[48, 96, 192, 384],
            channels=720,
            input_transform='resize_concat',
            in_index=(0, 1, 2, 3),
            kernel_size=1,
            num_convs=1,
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            concat_input=False,
            dropout_ratio=-1,
            num_classes=19,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        dict(
            type='ER5OCRHead',
            in_channels=[48, 96, 192, 384],
            channels=512,
            ocr_channels=256,
            input_transform='resize_concat',
            in_index=(0, 1, 2, 3),
            norm_cfg=dict(type='SyncBN', requires_grad=True),
            dropout_ratio=-1,
            num_classes=19,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    ],
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
dataset_type = 'CityscapesDataset'
data_root = 'data/cityscapes/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='GenerateBoundary', dilation=0.005),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(512, 1024), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size=(512, 1024), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='CityscapesDataset',
        data_root='data/cityscapes/',
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='GenerateBoundary', dilation=0.005),
            dict(
                type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
            dict(type='RandomCrop', crop_size=(512, 1024), cat_max_ratio=0.75),
            dict(type='RandomFlip', prob=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size=(512, 1024), pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(
                type='Collect',
                keys=['img', 'gt_semantic_seg', 'gt_boundary_seg'])
        ]),
    val=dict(
        type='CityscapesDataset',
        data_root='data/cityscapes/',
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(2048, 1024),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CityscapesDataset',
        data_root='data/cityscapes/',
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(2048,1024),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
log_config = dict(
    interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=80000, metric='mIoU')
work_dir = 'work_dirs/biou/ER5_shixin_OCR_HRNet_80k_feat/'
gpu_ids = range(0, 1)

Also, in line with making CBL as easy as possible to use in mmseg, the simplest way to apply CBL to other classic segmentation networks is to modify their decoders in the same way as the CBLocr_head.py we provide.

For example, one could write a CBLuper_head.py like the following (CBL + UperNet, used to reproduce CBL + Swin):


import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16, force_fp32

from mmseg.ops import resize
from ..builder import HEADS, build_loss
from ..losses import accuracy
from .decode_head import BaseDecodeHead
from .psp_head import PPM

class NeighborExtractor5(nn.Module):
    def __init__(self, input_channel):
        super(NeighborExtractor5, self).__init__()
        same_class_neighbor = np.array([[1, 1, 1, 1, 1], 
                                        [1, 1, 1, 1, 1], 
                                        [1, 1, 0, 1, 1], 
                                        [1, 1, 1, 1, 1],
                                        [1, 1, 1, 1, 1], ], dtype='float32')
        same_class_neighbor = same_class_neighbor.reshape((1, 1, 5, 5))
        same_class_neighbor = np.repeat(same_class_neighbor, input_channel, axis=0)
        self.same_class_extractor = nn.Conv2d(input_channel, input_channel, kernel_size=5, padding=2, bias=False, groups=input_channel)
        self.same_class_extractor.weight.data = torch.from_numpy(same_class_neighbor)

    def forward(self, feat):
        output = self.same_class_extractor(feat)
        return output

@HEADS.register_module()
class ERUPerHead(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(ERUPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        base_weight = np.array([[1, 1, 1, 1, 1], 
                                [1, 1, 1, 1, 1], 
                                [1, 1, 0, 1, 1], 
                                [1, 1, 1, 1, 1],
                                [1, 1, 1, 1, 1], ], dtype='float32')
        base_weight = base_weight.reshape((1, 1, 5, 5))
        self.same_class_extractor_weight = np.repeat(base_weight, 512, axis=0)
        self.same_class_extractor_weight = torch.FloatTensor(self.same_class_extractor_weight)
        # self.same_class_extractor_weight.requires_grad(False)
        self.same_class_number_extractor_weight = base_weight
        self.same_class_number_extractor_weight = torch.FloatTensor(self.same_class_number_extractor_weight)
        # self.same_class_number_extractor_weight.requires_grad(False)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function."""

        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output_er = self.fpn_bottleneck(fpn_outs)
        output = self.cls_seg(output_er)
        return output, output_er

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, **kwargs):
        seg_logits, output_er = self.forward(inputs)
        losses = self.losses(seg_logits, output_er, gt_semantic_seg, kwargs['gt_boundary_seg'])
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        seg_logits, edge_logits = self.forward(inputs)
        return seg_logits

    @force_fp32(apply_to=('seg_logit', 'output_er'))
    def losses(self, seg_logit, output_er, seg_label, gt_boundary_seg):
        """Compute segmentation loss."""
        loss = dict()
        loss['loss_context'] = self.context_loss(output_er, seg_label, gt_boundary_seg)
        loss['loss_NCE'], loss['loss_CN'] = self.er_loss(output_er, seg_label, seg_logit, gt_boundary_seg)
        loss['loss_NCE'] = loss['loss_NCE']*0.2
        loss['loss_CN'] = loss['loss_CN']*2

        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)
        loss['loss_seg'] = self.loss_decode(
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(seg_logit, seg_label)

        return loss

    def context_loss(self, er_input, seg_label, gt_boundary_seg, kernel_size=5):
        seg_label = F.interpolate(seg_label.float(), size=er_input.shape[2:], mode='nearest').long()
        gt_boundary_seg = F.interpolate(gt_boundary_seg.unsqueeze(1).float(), size=er_input.shape[2:], mode='nearest').long()
        context_loss_final = torch.tensor(0.0, device=er_input.device)
        context_loss = torch.tensor(0.0, device=er_input.device)
        gt_b = gt_boundary_seg

        gt_b[gt_b==255]=0
        seg_label_copy = seg_label.clone()
        seg_label_copy[seg_label_copy==255]=0
        gt_b = gt_b*seg_label_copy
        seg_label_one_hot = F.one_hot(seg_label.squeeze(1), num_classes=256)[:,:,:,0:self.num_classes].permute(0,3,1,2)

        b,c,h,w = er_input.shape
        scale_num = b
        for i in range(b):
            cal_mask = (gt_b[i][0]>0).bool()
            if cal_mask.sum()<1:
                scale_num = scale_num-1
                continue

            position = torch.where(gt_b[i][0])
            position_mask = ((kernel_size//2)<=position[0]) * (position[0]<=(er_input.shape[-2]-1-(kernel_size//2))) * ((kernel_size//2)<=position[1]) * (position[1]<=(er_input.shape[-1]-1-(kernel_size//2)))
            position_selected = (position[0][position_mask], position[1][position_mask])
            position_shift_list = []
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    if ki==kj==(kernel_size//2):
                        continue
                    position_shift_list.append((position_selected[0]+ki-(kernel_size//2),position_selected[1]+kj-(kernel_size//2)))
            # context_loss_batchi = torch.zeros_like(er_input[i].permute(1,2,0)[position_selected][0])
            context_loss_pi = torch.tensor(0.0, device=er_input.device)
            for pi in range(len(position_shift_list)):
                boudary_simi = F.cosine_similarity(er_input[i].permute(1,2,0)[position_selected], er_input[i].permute(1,2,0)[position_shift_list[pi]], dim=1)
                boudary_simi_label = torch.sum(seg_label_one_hot[i].permute(1,2,0)[position_selected] * seg_label_one_hot[i].permute(1,2,0)[position_shift_list[pi]], dim=-1)
                context_loss_pi = context_loss_pi + F.mse_loss(boudary_simi, boudary_simi_label.float())
            context_loss += (context_loss_pi / len(position_shift_list))
        context_loss = context_loss/scale_num
        if torch.isnan(context_loss):
            return context_loss_final
        else:
            return context_loss

    def er_loss(self, er_input, seg_label, seg_logit, gt_boundary_seg):
        shown_class = list(seg_label.unique())
        pred_label = seg_logit.max(dim=1)[1]
        pred_label_one_hot = F.one_hot(pred_label, num_classes=self.num_classes).permute(0,3,1,2)
        seg_label = F.interpolate(seg_label.float(), size=er_input.shape[2:], mode='nearest').long()
        gt_boundary_seg = F.interpolate(gt_boundary_seg.unsqueeze(1).float(), size=er_input.shape[2:], mode='nearest').long()

        gt_b = gt_boundary_seg

        gt_b[gt_b==255]=0
        edge_mask = gt_b.squeeze(1)

        seg_label_one_hot = F.one_hot(seg_label.squeeze(1), num_classes=256)[:,:,:,0:self.num_classes].permute(0,3,1,2)
        if self.same_class_extractor_weight.device!=er_input.device: 
            self.same_class_extractor_weight = self.same_class_extractor_weight.to(er_input.device)
            print("er move:",self.same_class_extractor_weight.device)
        if self.same_class_number_extractor_weight.device!=er_input.device: 
            self.same_class_number_extractor_weight = self.same_class_number_extractor_weight.to(er_input.device)
        # print(self.same_class_number_extractor_weight)
        same_class_extractor = NeighborExtractor5(512)
        same_class_extractor.same_class_extractor.weight.data = self.same_class_extractor_weight
        same_class_number_extractor = NeighborExtractor5(1)
        same_class_number_extractor.same_class_extractor.weight.data = self.same_class_number_extractor_weight

        try:
            shown_class.remove(torch.tensor(255))
        except:
            pass
        # er_input = er_input.permute(0,2,3,1)
        neigh_classfication_loss_total = torch.tensor(0.0, device=er_input.device)
        close2neigh_loss_total = torch.tensor(0.0, device=er_input.device)
        cal_class_num = len(shown_class)
        for i in range(len(shown_class)):
            now_class_mask = seg_label_one_hot[:,shown_class[i],:,:]
            now_pred_class_mask = pred_label_one_hot[:,shown_class[i],:,:]

            now_neighbor_feat = same_class_extractor(er_input*now_class_mask.unsqueeze(1))
            now_correct_neighbor_feat = same_class_extractor(er_input*(now_class_mask*now_pred_class_mask).unsqueeze(1))
            now_class_num_in_neigh = same_class_number_extractor(now_class_mask.unsqueeze(1).float())
            now_correct_class_num_in_neigh = same_class_number_extractor((now_class_mask*now_pred_class_mask).unsqueeze(1).float())

            pixel_cal_mask = (now_class_num_in_neigh.squeeze(1)>=1)*(edge_mask.bool()*now_class_mask.bool()).detach()
            pixel_mse_cal_mask = (now_correct_class_num_in_neigh.squeeze(1)>=1)*(edge_mask.bool()*now_class_mask.bool()*now_pred_class_mask.bool()).detach()
            if pixel_cal_mask.sum()<1 or pixel_mse_cal_mask.sum()<1:
                cal_class_num = cal_class_num - 1
                continue            
            class_forward_feat = now_neighbor_feat/(now_class_num_in_neigh+1e-5)
            class_correct_forward_feat = now_correct_neighbor_feat/(now_correct_class_num_in_neigh+1e-5)

            origin_mse_pixel_feat = er_input.permute(0,2,3,1)[pixel_mse_cal_mask].permute(1,0).unsqueeze(0).unsqueeze(-1)

            neigh_pixel_feat = class_forward_feat.permute(0,2,3,1)[pixel_cal_mask].permute(1,0).unsqueeze(0).unsqueeze(-1)
            neigh_mse_pixel_feat = class_correct_forward_feat.permute(0,2,3,1)[pixel_mse_cal_mask].permute(1,0).unsqueeze(0).unsqueeze(-1)

            neigh_pixel_feat_prediction = F.conv2d(neigh_pixel_feat, weight=self.conv_seg.weight.to(neigh_pixel_feat.dtype).detach(), bias=self.conv_seg.bias.to(neigh_pixel_feat.dtype).detach())

            gt_for_neigh_output = shown_class[i]*torch.ones((1,neigh_pixel_feat_prediction.shape[2],1)).to(er_input.device).long()
            neigh_classfication_loss = F.cross_entropy(neigh_pixel_feat_prediction, gt_for_neigh_output)

            close2neigh_loss = F.mse_loss(origin_mse_pixel_feat, neigh_mse_pixel_feat.detach())
            neigh_classfication_loss_total = neigh_classfication_loss_total + neigh_classfication_loss
            close2neigh_loss_total = close2neigh_loss_total + close2neigh_loss
        if cal_class_num==0:
            return neigh_classfication_loss_total, close2neigh_loss_total
        neigh_classfication_loss_total = neigh_classfication_loss_total / cal_class_num
        close2neigh_loss_total = close2neigh_loss_total / cal_class_num
        return neigh_classfication_loss_total, close2neigh_loss_total
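
A practical note, not stated in the thread but implied by how the mmseg 0.x registry works: @HEADS.register_module() only takes effect when the file is imported, so for the 'ERUPerHead' string in a config to resolve, the new file also has to be imported in the decode-head package. A minimal sketch, assuming the code above is saved as mmseg/models/decode_heads/CBLuper_head.py:

# mmseg/models/decode_heads/__init__.py (sketch; keep the existing entries)
from .CBLuper_head import ERUPerHead

__all__ = [
    # ... existing head names ...
    'ERUPerHead',
]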

Here is the complete config for CBL + Swin-B:

norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    # pretrained='pretrain/swin_base_patch4_window7_224_22k.pth',
    pretrained=None,
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=224,
        embed_dims=128,
        patch_size=4,
        window_size=7,
        mlp_ratio=4,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=dict(type='LN', requires_grad=True)),
    decode_head=dict(
        type='ERUPerHead',
        in_channels=[128, 256, 512, 1024],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=512,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='GenerateBoundary', dilation=0.02),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', reduce_zero_label=True),
            dict(type='GenerateBoundary', dilation=0.02),
            dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
            dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
            dict(type='RandomFlip', prob=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg'])
        ]),
    val=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(2048, 512),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='ADE20KDataset',
        data_root='data/ade/ADEChallengeData2016',
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(2048, 512),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
log_config = dict(
    interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(
    type='AdamW',
    lr=6e-05,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys=dict(
            absolute_pos_embed=dict(decay_mult=0.0),
            relative_position_bias_table=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0))))
optimizer_config = dict()
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU', pre_eval=True)
work_dir = 'work_dirs/biou/swin/'
gpu_ids = range(0, 4)
auto_resume = False
euphonium1998 commented 9 months ago

Thank you very much for your reply! I am currently using mmseg v0.30, not the 1.x codebase. After several days of debugging, I believe the core problem is that gt_boundary_seg is not being loaded correctly. I suspect you may have modified the pipeline but forgot to mention it in the README. If possible, I would appreciate it if you could send me the earlier version of the CBL + mmsegmentation code along with the environment. Thank you very much for your help! Email: euphonium1998@163.com

dywu98 commented 9 months ago

Sent. Once again, apologies for my messy code management T.T

yychayitu commented 9 months ago

Thank you very much for attaching the CBL + Swin-B config. I am using mmsegmentation 1.1.0 and wanted to try adapting the code myself, but that version has neither from mmcv.runner import auto_fp16, force_fp32 (used in CBLuperhead) nor dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_boundary_seg']) (used in train_pipeline). The former I can apparently just comment out, but the latter affects the use of boundary.py. (screenshot attached)

Looking forward to the code update.

dywu98 commented 9 months ago

Yes, auto_fp16 and force_fp32 only set the computation precision; they affect only speed and numerical precision and do not affect the custom head.
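
In code terms, a sketch assuming a port to mmseg 1.x / mmcv 2.x (where mmcv.runner no longer exists), with names taken from the CBLuper_head.py above:

# The mmcv 1.x import and decorator can simply be removed when porting:
# from mmcv.runner import auto_fp16, force_fp32

# @force_fp32(apply_to=('seg_logit', 'output_er'))  # removed; only the automatic
#                                                   # fp32 casting is lost
def losses(self, seg_logit, output_er, seg_label, gt_boundary_seg):
    # the loss computation itself is unchanged
    ...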

I will base the upcoming update on the new mmseg version. Also, I noticed that mmseg 1.2 provides a feature similar to the boundary-generation step in this work: see the GenerateEdge transform in mmseg/datasets/transforms/transforms.py.

Note, however, that the boundary produced by our method has a width, set as a percentage of the image diagonal (e.g. 0.02 on ADE20K), which we control with the dilation parameter. With GenerateEdge you can control the width in pixels via its edge_width parameter.
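
For reference, a training pipeline built around GenerateEdge might look like the sketch below. This is an assumption about the mmseg 1.x config style rather than an official CBL config; also note that GenerateEdge packs its output as gt_edge_map, whereas the code above reads gt_boundary_seg, so the head would have to fetch the corresponding field from the data sample:

crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='RandomResize', scale=(2048, 512), ratio_range=(0.5, 2.0), keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    # edge_width is an absolute pixel width; 0.02 of the diagonal of a
    # 512x512 crop is about 0.02 * 724 ≈ 15 pixels
    dict(type='GenerateEdge', edge_width=15),
    dict(type='PackSegInputs')
]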

yychayitu commented 9 months ago

Thank you very much for the quick reply. I used GenerateEdge with edge_width set to 15 (image size 512), and got the following error when running:

Traceback (most recent call last):
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/tools/train.py", line 108, in <module>
    main()
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/tools/train.py", line 104, in main
    runner.train()
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/mmengine/runner/runner.py", line 1735, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/mmengine/runner/loops.py", line 278, in run
    self.run_iter(data_batch)
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/mmengine/runner/loops.py", line 302, in run_iter
    data_batch, optim_wrapper=self.runner.optim_wrapper)
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/mmengine/model/base_model/base_model.py", line 340, in _run_forward
    results = self(*data, mode=mode)
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/segmentors/base.py", line 94, in forward
    return self.loss(inputs, data_samples)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/segmentors/encoder_decoder.py", line 176, in loss
    loss_decode = self._decode_head_forward_train(x, data_samples)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/segmentors/encoder_decoder.py", line 138, in _decode_head_forward_train
    self.train_cfg)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/decode_heads/decode_head.py", line 262, in loss
    losses = self.loss_by_feat(seg_logits, batch_data_samples)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/decode_heads/decode_head.py", line 311, in loss_by_feat
    align_corners=self.align_corners)
  File "/media/oyasumi/CDA0125873789844/yy/mmsegmentation-new/mmseg/models/utils/wrappers.py", line 27, in resize
    return F.interpolate(input, size, scale_factor, mode, align_corners)
  File "/home/oyasumi/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/functional.py", line 3459, in interpolate
    dim = input.dim() - 2  # Number of spatial dimensions.
AttributeError: 'tuple' object has no attribute 'dim'

I tried printing the type of input in functional.py and got: <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'tuple'>

Do you know how to fix this?

dywu98 commented 9 months ago

Sorry, I am still looking into this; I will notify you when the official update is released.

yychayitu commented 1 month ago

Hello, I also cannot load 'gt_boundary_seg'; it seems mmseg is missing the corresponding setup. Could you tell me how to modify it, or send the source code to 771263786@qq.com? Thank you very much.

dywu98 commented 1 month ago

Sent~ If you run into any problems, feel free to keep asking.