WZMIAOMIAO / deep-learning-for-image-processing

deep learning for image processing including classification and object-detection etc.
GNU General Public License v3.0

Accuracy dropped after switching the SSD network to CIoU loss. Did I make a mistake in the change? #422

Closed 1757732743 closed 2 years ago

1757732743 commented 2 years ago
    bboxes_in = ploc.permute(0, 2, 1)

    # [batch, label_num, 8732] -> [batch, 8732, label_num]
    # scores_in = plabel.permute(0, 2, 1)
    # print(bboxes_in.is_contiguous())

    # Apply the predicted regression parameters to the default boxes to get the final predicted boxes
    bboxes_in[:, :, :2] = bboxes_in[:, :, :2] * self.dboxes_xywh[:, :, 2:] + self.dboxes_xywh[:, :, :2]
    bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp() * self.dboxes_xywh[:, :, 2:]

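    # Scale the relative (cx, cy, w, h) values by 300
    bboxes_in[:, :, [0, 2]] = bboxes_in[:, :, [0, 2]] * 300
    bboxes_in[:, :, [1, 3]] = bboxes_in[:, :, [1, 3]] * 300
    # Note: this clamps cx, cy, w and h themselves, not the box corners
    bboxes_in = bboxes_in.clamp(min=0, max=300)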
    # gloc = self._location_vec(gloc)
    gloc = gloc.permute(0, 2, 1)
    # # Regression parameters at this point; gloc is now already the real GT (with the trick applied)
    # gloc[:, :, :2] = gloc[:, :, :2] * self.dboxes_xywh[:, :, 2:] + self.dboxes_xywh[:, :, :2]
    # gloc[:, :, 2:] = gloc[:, :, 2:].exp() * self.dboxes_xywh[:, :, 2:]

    gloc[:, :, [0, 2]] = gloc[:, :, [0, 2]] * 300
    gloc[:, :, [1, 3]] = gloc[:, :, [1, 3]] * 300
    # Start computing the CIoU loss
    b1 = bboxes_in
    b2 = gloc
    b1_xy = b1[..., :2]
    b1_wh = b1[..., 2:4]
    b1_wh_half = b1_wh / 2.
    b1_mins = b1_xy - b1_wh_half
    b1_maxes = b1_xy + b1_wh_half
    # Top-left and bottom-right corners of the ground-truth boxes
    b2_xy = b2[..., :2]
    b2_wh = b2[..., 2:4]
    b2_wh_half = b2_wh / 2.
    b2_mins = b2_xy - b2_wh_half
    b2_maxes = b2_xy + b2_wh_half
    # IoU between every ground-truth box and predicted box
    intersect_mins = torch.max(b1_mins, b2_mins)
    intersect_maxes = torch.min(b1_maxes, b2_maxes)
    # Remove values < 0 (boxes that do not intersect)
    intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
    intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
    b1_area = b1_wh[..., 0] * b1_wh[..., 1]
    b2_area = b2_wh[..., 0] * b2_wh[..., 1]
    union_area = b1_area + b2_area - intersect_area
    iou = intersect_area / torch.clamp(union_area, min=1e-6)
    # Squared distance between the box centers
    center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
    # Top-left and bottom-right corners of the smallest box enclosing both boxes
    enclose_mins = torch.min(b1_mins, b2_mins)
    enclose_maxes = torch.max(b1_maxes, b2_maxes)
    enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
    # Squared diagonal length of the enclosing box
    enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), axis=-1)
    ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal, min=1e-6)
    v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(
        b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min=1e-6)) - torch.atan(
        b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min=1e-6))), 2)
    alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
    ciou = ciou - alpha * v
    ciou = 1 - ciou
WZMIAOMIAO commented 2 years ago

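For problems in modified code, you will have to look into it yourself. Whether results improve after changing the loss also depends on some hyperparameters, so you will need to experiment and tune them yourself.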
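
One way to check a modified loss like the one posted above is to compare it against a scalar reference on a few hand-picked boxes. The following is a minimal sketch in plain Python, not code from this repository. It assumes both boxes are given as (cx, cy, w, h) in the same units, and the function name `ciou_loss_cxcywh` and the `eps` parameter are hypothetical. It follows the same steps as the tensor code in the question: IoU, center-distance penalty, then the aspect-ratio term.

```python
import math


def ciou_loss_cxcywh(box_a, box_b, eps=1e-6):
    """Return 1 - CIoU for two boxes given as (cx, cy, w, h).

    Hypothetical helper for sanity checks only; it mirrors the steps of
    the tensor code in the question but works on a single pair of boxes.
    """
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b

    # Convert center format to corner coordinates.
    a_x1, a_y1, a_x2, a_y2 = ax - aw / 2, ay - ah / 2, ax + aw / 2, ay + ah / 2
    b_x1, b_y1, b_x2, b_y2 = bx - bw / 2, by - bh / 2, bx + bw / 2, by + bh / 2

    # Intersection area, clipped at zero when the boxes do not overlap.
    inter_w = max(min(a_x2, b_x2) - max(a_x1, b_x1), 0.0)
    inter_h = max(min(a_y2, b_y2) - max(a_y1, b_y1), 0.0)
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    iou = inter / max(union, eps)

    # Squared center distance over squared diagonal of the enclosing box.
    center_dist = (ax - bx) ** 2 + (ay - by) ** 2
    enclose_w = max(a_x2, b_x2) - min(a_x1, b_x1)
    enclose_h = max(a_y2, b_y2) - min(a_y1, b_y1)
    diagonal = enclose_w ** 2 + enclose_h ** 2

    # Aspect-ratio consistency term v and its weight alpha.
    v = (4 / math.pi ** 2) * (
        math.atan(aw / max(ah, eps)) - math.atan(bw / max(bh, eps))
    ) ** 2
    alpha = v / max(1.0 - iou + v, eps)

    ciou = iou - center_dist / max(diagonal, eps) - alpha * v
    return 1.0 - ciou


# Identical boxes should give a loss of (almost exactly) zero.
same = ciou_loss_cxcywh((150, 150, 60, 40), (150, 150, 60, 40))
# Disjoint boxes have IoU 0, so the loss exceeds 1 because of the distance penalty.
far = ciou_loss_cxcywh((50, 50, 20, 20), (250, 250, 20, 20))
print(same, far)
```

Feeding the same boxes through the tensor version (for example as a batch of one box) and comparing the numbers can help separate a coding mistake from a hyperparameter issue. If the tensor loss is also averaged over all 8732 default boxes rather than only the matched ones, that is a design choice worth checking separately; this sketch does not cover it.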