HKUST-Aerial-Robotics / Stereo-RCNN

Code for 'Stereo R-CNN based 3D Object Detection for Autonomous Driving' (CVPR 2019)
MIT License

Doubt about the inconsistency of RPN cls score computation during inference and training #59

Closed chenchr closed 4 years ago

chenchr commented 4 years ago

Hi peiliang, I just read the code of the mono branch. Here, the softmax seems to be applied to a tensor of shape b * 2 * 3h * w, giving the probabilities; the result is then permuted to shape b * h * w * 6, reshaped to b * h * w * 3 * 2, and finally viewed as b * (h*w*3) * 2, where the last dimension is treated as the background/foreground probability. That pairing is inconsistent with the original shape the softmax was applied on.
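The mismatch can be checked numerically. A minimal NumPy sketch (NumPy's row-major `reshape`/`transpose` behave like torch's `view`/`permute`; the batch size, anchor count a = 3, and feature-map size are made-up small values, not the repo's):

```python
import numpy as np

b, a, h, w = 1, 3, 2, 2                        # batch, anchors/location, map size
rng = np.random.default_rng(0)
score = rng.standard_normal((b, 2 * a, h, w))  # raw RPN logits, 2*a channels

# softmax over dim 1 of the (b, 2, a*h, w) view, as described above
s = score.reshape(b, 2, a * h, w)
prob = np.exp(s) / np.exp(s).sum(axis=1, keepdims=True)
prob = prob.reshape(b, 2 * a, h, w)            # back to the channel layout

# the inference path then permutes and views the last dim as (bg, fg):
# this pairs consecutive channels (0,1), (2,3), (4,5), not the softmax pairs
pairs = prob.transpose(0, 2, 3, 1).reshape(b, h * w * a, 2)
print(np.allclose(pairs.sum(-1), 1.0))         # generally False: wrong pairing

# the softmax actually paired channel c (bg) with channel c + a (fg)
correct = np.stack([prob[:, :a], prob[:, a:]], axis=-1)   # (b, a, h, w, 2)
correct = correct.transpose(0, 2, 3, 1, 4).reshape(b, h * w * a, 2)
print(np.allclose(correct.sum(-1), 1.0))       # True: each pair sums to 1
```

So the values read out of the last dimension are probabilities, but each (bg, fg) pair mixes entries from two different softmax normalizations.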

chenchr commented 4 years ago

https://github.com/jwyang/fpn.pytorch/issues/6

xmyqsh commented 4 years ago

@chenchr First solution:

            rpn_cls_score_reshape = rpn_cls_score.view(
                                                       rpn_cls_score.size(0),
                                                       2,
                                                       rpn_cls_score.size(1) // 2,
                                                       rpn_cls_score.size(2),
                                                       rpn_cls_score.size(3)).permute(0, 1, 3, 4, 2).contiguous()
            rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, 1)

            # get rpn offsets to the anchor boxes
            rpn_bbox_pred_left_right = self.RPN_bbox_pred_left_right(rpn_conv1)

            rpn_shapes.append([rpn_cls_score.size()[2], rpn_cls_score.size()[3]])
            rpn_cls_scores.append(rpn_cls_score_reshape.permute(0, 2, 3, 4, 1).contiguous().view(batch_size, -1, 2))
            rpn_cls_probs.append(rpn_cls_prob_reshape.permute(0, 2, 3, 4, 1).contiguous().view(batch_size, -1, 2))

Best solution: change

# define bg/fg classification score layer
        self.nc_score_out = 1 * len(self.anchor_ratios) * 2 # 2(bg/fg) * 3 (anchor ratios) * 1 (anchor scale)
        self.RPN_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

to

# define bg/fg classification score layer
        self.nc_score_out = 1 * len(self.anchor_ratios) # 1(sigmoid(bg/fg)) * 3 (anchor ratios) * 1 (anchor scale)
        self.RPN_cls_score = nn.Conv2d(512, self.nc_score_out, 1, 1, 0)

then change the softmax to a sigmoid, change cross_entropy_loss to binary_cross_entropy_loss, and adjust the related shape handling accordingly. The shape handling looks like this:

            rpn_cls_score_reshape = rpn_cls_score.view(rpn_cls_score.size(0), -1)
            rpn_cls_prob_reshape = torch.sigmoid(rpn_cls_score_reshape)

            # get rpn offsets to the anchor boxes
            rpn_bbox_pred_left_right = self.RPN_bbox_pred_left_right(rpn_conv1)

            rpn_shapes.append([rpn_cls_score.size()[2], rpn_cls_score.size()[3]])
            rpn_cls_scores.append(rpn_cls_score_reshape)
            rpn_cls_probs.append(rpn_cls_prob_reshape)
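The softmax-to-sigmoid swap works because a single foreground logit z under a sigmoid equals the two-channel pair (0, z) under a softmax, so binary cross-entropy on one channel matches cross-entropy on the pair. A small numerical check in plain NumPy (a sketch with made-up numbers, not the repo's code):

```python
import numpy as np

z = 1.7          # hypothetical fg logit for one anchor
y = 1.0          # ground-truth label (1 = foreground)

# single-channel head: sigmoid + binary cross-entropy
p_sig = 1.0 / (1.0 + np.exp(-z))
bce = -(y * np.log(p_sig) + (1 - y) * np.log(1 - p_sig))

# two-channel head: (bg, fg) logits (0, z) + softmax + cross-entropy
logits = np.array([0.0, z])
p_soft = np.exp(logits) / np.exp(logits).sum()
ce = -np.log(p_soft[1] if y == 1 else p_soft[0])

print(np.isclose(bce, ce))   # True: the two losses coincide
```

This is why collapsing nc_score_out from 2 * len(anchor_ratios) to len(anchor_ratios) loses no modeling capacity, while making the reshape ambiguity impossible.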
chenchr commented 4 years ago

@xmyqsh Thanks. Do you have any experimental results comparing the different forms of the rpn_cls_score computation? It seems that performance is the same whether or not the code is changed, which is weird.