CaoWGG / CenterNet-CondInst

Instance Segmentation based on CenterNet and CondInst
MIT License
166 stars 27 forks source link

关于CenterNet后接segmentation的请教 #13

Open NeuZhangQiang opened 3 years ago

NeuZhangQiang commented 3 years ago

非常感谢您分享的代码。

由于我暂时还没有配置成功代码,所以只能干看代码,没法调试。

有一个问题想请教一下,就是原始的CenterNet的输出有三个分支,分别是 heatmap (W*H*C),offset (W*H*2)和size (W*H*2),然后你这里加了一个seg_feat,这个分支是怎么加的,能介绍一下吗?能否告知是在代码的哪一处?这里的seg_feat它的size是什么样子的?怎么为每个中心点分配一个mask?难道与offset和size一样,预测一个 W*H*W*H的seg_feat?

此外,代码中关于dice loss的计算,我也不是很明白:

    def forward(self, seg_feat, conv_weight, mask,ind, target):
        mask_loss=0.
        batch_size = seg_feat.size(0)
        weight = _tranpose_and_gather_feat(conv_weight, ind)
        h,w = seg_feat.size(-2),seg_feat.size(-1)
        x,y = ind%w,ind/w
        x_range = torch.arange(w).float().to(device=seg_feat.device)
        y_range = torch.arange(h).float().to(device=seg_feat.device)
        y_grid, x_grid = torch.meshgrid([y_range, x_range])
        for i in range(batch_size):
            num_obj = target[i].size(0)
            conv1w,conv1b,conv2w,conv2b,conv3w,conv3b= \
                torch.split(weight[i,:num_obj],[(self.feat_channel+2)*self.feat_channel,self.feat_channel,
                                          self.feat_channel**2,self.feat_channel,
                                          self.feat_channel,1],dim=-1)
            y_rel_coord = (y_grid[None,None] - y[i,:num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float())/128.
            x_rel_coord = (x_grid[None,None] - x[i,:num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float())/128.
            feat = seg_feat[i][None].repeat([num_obj,1,1,1])
            feat = torch.cat([feat,x_rel_coord, y_rel_coord],dim=1).view(1,-1,h,w)

            conv1w=conv1w.contiguous().view(-1,self.feat_channel+2,1,1)
            conv1b=conv1b.contiguous().flatten()
            feat = F.conv2d(feat,conv1w,conv1b,groups=num_obj).relu()

            conv2w=conv2w.contiguous().view(-1,self.feat_channel,1,1)
            conv2b=conv2b.contiguous().flatten()
            feat = F.conv2d(feat,conv2w,conv2b,groups=num_obj).relu()

            conv3w=conv3w.contiguous().view(-1,self.feat_channel,1,1)
            conv3b=conv3b.contiguous().flatten()
            feat = F.conv2d(feat,conv3w,conv3b,groups=num_obj).sigmoid().squeeze()

            true_mask = mask[i,:num_obj,None,None].float()
            mask_loss+=dice_loss(feat*true_mask,target[i]*true_mask)

        return mask_loss/batch_size

里面还进行了卷积计算?能否说明一下思路?

非常期待您的回复。

BloodLemonS commented 9 months ago

您好,请问这个代码您看懂了吗?您提到的这个问题我现在也比较疑惑,请问方便解答一下吗