huawei-noah / Efficient-AI-Backbones

Efficient AI Backbones including GhostNet, TNT and MLP, developed by Huawei Noah's Ark Lab.

PViG for object detection #252

Open yuan0038 opened 6 months ago

yuan0038 commented 6 months ago

Hi @iamhankai (Kai Han), I'd like to ask about the COCO object detection configuration of PViG. My environment and setup:

① Library: mmdet; hardware: 4×A6000 (48 GB)
② Detection framework: the officially provided Mask R-CNN, with only the backbone replaced by PViG_S during training
③ Training schedule: 1× schedule
④ 2 images per GPU (total batch size 8), img scale (1333, 800)

(A rough sketch of this kind of backbone swap is included at the end of this comment.) With this setup, the following problem appeared:


CUDA error: an illegal memory access was encountered

From searching related material online, the cause might be insufficient GPU memory, so I tried:

  1. Halving the img scale resolution, keeping PViG_S as the backbone (45.8M parameters, same as the paper, so the model should be built correctly)
  2. img scale (1333, 800) with a different backbone (also 45.8M parameters)
  3. img scale (1333, 800) with PViG_Ti instead (29.3M parameters)

Cases 1 and 2 run normally; case 3 fails as follows:
    x = self.grapher(x)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 189, in forward
    x = self.graph_conv(x, relative_pos)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 138, in forward
    x = super(DyGraphConv2d, self).forward(x, edge_index, y)
    File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 111, in forward
    return self.gconv(x, edge_index, y)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 34, in forward
    return self.nn(x)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 757, in forward
    world_size,
    File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/_functions.py", line 80, in forward
    count_all = count_all[mask]
    RuntimeError: CUDA error: an illegal memory access was encountered

⭐️ So my question is: following the standard recipe (the total batch is smaller, but, just like the standard recipe, each GPU still processes 2 images), why can't an A6000 handle PViG_S? Any pointers would be much appreciated 🌹
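
For reference, a minimal sketch of this kind of backbone swap in an mmdet 2.x config (the registry name pvig_s_feat is an assumption, and the stage channels are taken from the pvig_s_feat code quoted below; this is not the exact config used here):

# mask_rcnn_pvig_s_fpn_1x_coco.py -- hypothetical config name
_base_ = [
    '../_base_/models/mask_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]
model = dict(
    backbone=dict(
        _delete_=True,         # drop the ResNet-50 defaults entirely
        type='pvig_s_feat'),   # assumed BACKBONES registry name; checkpoint path and
                               # out_indices are hardcoded inside pvig_s_feat below
    neck=dict(in_channels=[80, 160, 400, 640]))  # PViG-S stage channels
data = dict(samples_per_gpu=2)  # 2 images per GPU -> total batch size 8 on 4 GPUs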

iamhankai commented 6 months ago

Could it be that the dilation parameter is set incorrectly? If it is too large, it needs to be made smaller.
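
For context: the dilated KNN in gcn_lib first gathers k * dilation candidate neighbours and then keeps every dilation-th one, so k * dilation must not exceed the number of tokens at a stage; that is what max_dilation = 49 // max(num_knn) enforces for the 7 × 7 = 49 tokens of the deepest stage at a 224 × 224 input. A minimal illustration of the constraint (a sketch, not the repo's exact code):

import torch

def dilated_knn(x, k=9, dilation=1):
    # x: (B, N, C) tokens. Dilated KNN over-selects k * dilation neighbours and
    # then keeps every `dilation`-th one (the DenseDilated pattern in gcn_lib).
    n = x.shape[1]
    assert k * dilation <= n, "k * dilation must not exceed the token count"
    dist = torch.cdist(x, x)                       # (B, N, N) pairwise distances
    _, nn_idx = torch.topk(-dist, k=k * dilation)  # k * dilation nearest neighbours
    return nn_idx[:, :, ::dilation]                # (B, N, k) dilated neighbour indices

x = torch.randn(2, 49, 80)                     # 49 tokens, as in max_dilation = 49 // k
print(dilated_knn(x, k=9, dilation=3).shape)   # torch.Size([2, 49, 9])
# dilation > 49 // 9 == 5 trips the assert; without such a cap the graph builder
# would request more neighbours than there are tokens.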

yuan0038 commented 6 months ago

For the detection backbone (e.g. pvig_s), the printed dilation is [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], so the model should be fine. I even tried setting all dilations to 1 and the same problem still occurs, which is really puzzling. 😂
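
As a quick check, that printed schedule does match the (commented-out) dilation rule in the code below:

blocks = [2, 2, 6, 2]       # pvig_s block counts
max_dilation = 49 // 9      # = 5, i.e. 49 // max(num_knn) with k = 9
dilation = [min(idx // 4 + 1, max_dilation) for idx in range(sum(blocks))]
print(dilation)             # [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]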

# Assumed imports (paths may differ in the local project); Stem, Downsample,
# Block and _cfg come from the PViG detection code itself.
import torch
import torch.nn as nn
from torch.nn import Sequential as Seq
from mmcv.runner import _load_checkpoint
from mmdet.utils import get_root_logger


class Pyramid_ViG(torch.nn.Module):
    def __init__(self, k, gconv, channels, blocks, n_classes, act, norm, bias, epsilon,
                 use_stochastic, dropout, drop_path, pretrained=None, out_indices=None):
        super().__init__()

        self.pretrained = pretrained
        self.out_indices = out_indices

        self.n_blocks = sum(blocks)
        reduce_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]  # stochastic depth decay rule
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]  # number of knn's k
        print(num_knn)
        max_dilation = 49 // max(num_knn)  # 49 = 7 * 7 tokens at the deepest stage of a 224 input

        self.stem = Stem(out_dim=channels[0], act=act)
        # positional embedding sized for a 224 input after the stride-4 stem (56 x 56)
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = 224 // 4 * 224 // 4  # 56 * 56 = 3136 tokens

        self.backbone = nn.ModuleList([])

        # original rule: dilation = [min(idx // 4 + 1, max_dilation) for idx in range(sum(blocks))]
        dilation = [1 for i in range(sum(blocks))]  # all dilations forced to 1 for this experiment
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                self.backbone.append(Downsample(channels[i - 1], channels[i]))
                HW = HW // 4
            for j in range(blocks[i]):
                self.backbone += [
                    Seq(Block(channels[i], num_knn[idx], dilation[idx], gconv, act, norm,
                              bias, use_stochastic, epsilon, reduce_ratios[i], n=HW,
                              drop_path=dpr[idx], relative_pos=True))
                ]
                idx += 1
        self.backbone = Seq(*self.backbone)
        print("\u2b50 dilation:",dilation)
        self.init_weights()
        # convert_sync_batchnorm replaces all BatchNorm children in place for container
        # modules; rebinding the local name `self` has no effect outside __init__
        self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)

    @torch.no_grad()
    def train(self, mode=True):
        # keep all BatchNorm2d layers frozen (eval mode) even during training
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def init_weights(self):
        logger = get_root_logger()
        print("Pretrained weights being loaded")
        logger.warning('Pretrained weights being loaded')
        ckpt_path = self.pretrained
        ckpt = _load_checkpoint(
            ckpt_path, logger=logger, map_location='cpu')
        print("ckpt keys: ", ckpt.keys())
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt
        state_dict = _state_dict
        new_state_dict = {}
        for k, v in state_dict.items():
            new_k = k.replace('.grapher', '')
            new_state_dict[new_k] = v
        print(new_state_dict.keys())
        missing_keys, unexpected_keys = \
            self.load_state_dict(new_state_dict, False)
        print("missing_keys: ", missing_keys)
        print("unexpected_keys: ", unexpected_keys)

    def interpolate_pos_encoding(self, x):
        w, h = x.shape[2], x.shape[3]
        p_w, p_h = self.pos_embed.shape[2], self.pos_embed.shape[3]

        if w * h == p_w * p_h and w == h:
            return self.pos_embed

        w0 = w
        h0 = h
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            self.pos_embed,
            scale_factor=(w0 / p_w, h0 / p_h),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        return patch_pos_embed

    def forward(self, inputs):
        outs=[]
        B, C, H, W = inputs.shape

        x = self.stem(inputs)

        x = x + self.interpolate_pos_encoding(x)

        for i in range(len(self.backbone)):

            x = self.backbone[i](x)
            if i in self.out_indices:
                outs.append(x)

        return outs
def pvig_s_feat(pretrained=True, **kwargs):
    model = Pyramid_ViG(
        k=9,  # neighbor num (default: 9)
        gconv='mr',  # graph conv layer {edge, mr}
        channels=[80, 160, 400, 640],  # number of channels of deep features
        blocks=[2, 2, 6, 2],  # number of basic blocks in the backbone
        n_classes=1000,  # dimension of out_channels
        act='gelu',  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
        norm='batch',  # batch or instance normalization {batch, instance}
        bias=True,  # bias of conv layer, True or False
        epsilon=0.2,  # stochastic epsilon for gcn
        use_stochastic=False,  # stochastic for gcn, True or False
        dropout=0.0,  # dropout rate
        drop_path=0.0,
        pretrained='../ckpt/pvig_s_82.1.pth.tar',
        out_indices=[1, 4, 11, 14])

    model.default_cfg = _cfg()
    return model
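
A quick way to sanity-check the backbone in isolation (a sketch; it assumes the class above is importable and the checkpoint referenced inside pvig_s_feat is available):

# Run the backbone alone on a detection-sized input: 2 images, 800 x 1344
# (mmdet pads the (1333, 800)-resized images to a multiple of 32).
import torch

model = pvig_s_feat().cuda().eval()
x = torch.randn(2, 3, 800, 1344).cuda()
with torch.no_grad():
    feats = model(x)
for f in feats:
    print(f.shape)  # one feature map per entry of out_indices=[1, 4, 11, 14]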