guanfuchen / objdet

Implements common one-stage and two-stage object detection networks.

yolo loss advance #12

Open guanfuchen opened 4 years ago

guanfuchen commented 4 years ago
import math
import random
import time
import logging as log
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def bbox_ious(boxes1, boxes2):
    """ Compute IOU between all boxes from ``boxes1`` with all boxes from ``boxes2``.

    Args:
        boxes1 (torch.Tensor): List of bounding boxes
        boxes2 (torch.Tensor): List of bounding boxes

    Note:
        List format: [[xc, yc, w, h],...]
    """

    b1x1, b1y1 = (boxes1[:, :2] - (boxes1[:, 2:4] / 2)).split(1, 1)
    b1x2, b1y2 = (boxes1[:, :2] + (boxes1[:, 2:4] / 2)).split(1, 1)
    b2x1, b2y1 = (boxes2[:, :2] - (boxes2[:, 2:4] / 2)).split(1, 1)
    b2x2, b2y2 = (boxes2[:, :2] + (boxes2[:, 2:4] / 2)).split(1, 1)
    # print('b1x1:', b1x1)
    # print('b1y1:', b1y1)

    dx = (b1x2.min(b2x2.t()) - b1x1.max(b2x1.t())).clamp(min=0)
    dy = (b1y2.min(b2y2.t()) - b1y1.max(b2y1.t())).clamp(min=0)
    intersections = dx * dy

    areas1 = (b1x2 - b1x1) * (b1y2 - b1y1)
    areas2 = (b2x2 - b2x1) * (b2y2 - b2y1)
    unions = (areas1 + areas2.t()) - intersections

    return intersections / unions
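
# A quick hand-check of bbox_ious (illustrative values, not from the repo):
#   a = torch.tensor([[0.5, 0.5, 1.0, 1.0]])  # unit box with corner at origin
#   b = torch.tensor([[0.5, 0.5, 1.0, 1.0],
#                     [5.0, 5.0, 1.0, 1.0]])
#   bbox_ious(a, b)  # -> tensor([[1., 0.]]): identical box, then disjoint box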

class YoloLossAdvance(nn.modules.loss._Loss):
    """ Computes yolo loss from darknet network output and target annotation.

    Args:
        num_classes (int): number of categories
        anchors (list): 2D list representing anchor boxes (see :class:`lightnet.network.Darknet`)
        anchors_mask (list): indices into ``anchors`` used by this detection head
        reduction (int): network stride; scales anchors from pixels to grid units
        coord_scale (float): weight of bounding box coordinates
        noobject_scale (float): weight of regions without target boxes
        object_scale (float): weight of regions with target boxes
        class_scale (float): weight of categorical predictions
        thresh (float): minimum iou between a predicted box and ground truth for them to be considered matching
        seen (int): how many images the network has already been trained on
        head_idx (int): index of this detection head, used for logging
    """
    def __init__(self, num_classes, anchors, anchors_mask, reduction=32, seen=0, coord_scale=1.0, noobject_scale=1.0, object_scale=1.0, class_scale=1.0, thresh=0.5, head_idx=0):
        super(YoloLossAdvance, self).__init__()
        self.num_classes = num_classes
        self.num_anchors = len(anchors_mask)
        self.anchor_step = len(anchors[0])
        self.anchors = torch.Tensor(anchors) / float(reduction)
        self.anchors_mask = list(anchors_mask)  # accept tuple or list
        self.reduction = reduction      
        self.seen = seen
        self.head_idx = head_idx

        self.coord_scale = coord_scale
        self.noobject_scale = noobject_scale
        self.object_scale = object_scale
        self.class_scale = class_scale
        self.thresh = thresh

        self.info = {'avg_iou': 0, 'class': 0, 'obj': 0, 'no_obj': 0, 
                'recall50': 0, 'recall75': 0, 'obj_cur': 0, 'obj_all': 0,
                'coord_xy': 0, 'coord_wh': 0}

        # criterion
        self.mse = nn.MSELoss(reduction='none')
        self.bce = nn.BCELoss(reduction='none')
        self.smooth_l1 = nn.SmoothL1Loss(reduction='none')
        self.ce = nn.CrossEntropyLoss(reduction='sum')
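        # reduction='none' keeps the per-element loss maps so the masks built
        # in build_targets can weight individual cells before summation; only
        # the class loss is reduced (summed) by the criterion itself.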

    # def forward(self, output, target, seen=None):
    def forward(self, output, bbox, label, seen=None):
        """ Compute Yolo loss.
        """
        # Parameters
        nB = output.size(0)
        nA = self.num_anchors
        nC = self.num_classes
        nH = output.size(2)
        nW = output.size(3)
        device = output.device

        if seen is not None:
            self.seen = seen
        else:
            self.seen += nB

        self.anchors = self.anchors.to(device)

        # Get x,y,w,h,conf,cls
        output = output.view(nB, nA, -1, nH*nW)
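        # After the view the channel dim holds, per anchor:
        # tx, ty, tw, th, objectness, then the nC class scores.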
        coord = torch.zeros_like(output[:, :, :4])
        coord[:, :, :2] = output[:, :, :2].sigmoid()    # tx,ty
        coord[:, :, 2:4] = output[:, :, 2:4]            # tw,th
        conf = output[:, :, 4].sigmoid()
        if nC > 1:
            cls = output[:, :, 5:].contiguous().view(nB*nA, nC, nH*nW).transpose(1, 2).contiguous().view(-1, nC)

        # Create prediction boxes
        # time consuming
        # pred_boxes_start_time = time.time()
        pred_boxes = torch.zeros(nB*nA*nH*nW, 4, dtype=torch.float, device=device)
        lin_x = torch.linspace(0, nW-1, nW).to(device).repeat(nH, 1).view(nH*nW)
        lin_y = torch.linspace(0, nH-1, nH).to(device).repeat(nW, 1).t().contiguous().view(nH*nW)
        # print('nA:', nA)
        anchor_w = self.anchors[self.anchors_mask, 0].view(nA, 1).to(device)
        anchor_h = self.anchors[self.anchors_mask, 1].view(nA, 1).to(device)
        # anchor_w = 10
        # anchor_h = 13

        pred_boxes[:, 0] = (coord[:, :, 0].detach() + lin_x).view(-1)
        pred_boxes[:, 1] = (coord[:, :, 1].detach() + lin_y).view(-1)
        pred_boxes[:, 2] = (coord[:, :, 2].detach().exp() * anchor_w).view(-1)
        pred_boxes[:, 3] = (coord[:, :, 3].detach().exp() * anchor_h).view(-1)
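        # The decode above is the standard YOLOv2/v3 mapping from raw outputs
        # to grid-space boxes:
        #   bx = sigmoid(tx) + cx,  by = sigmoid(ty) + cy
        #   bw = pw * exp(tw),      bh = ph * exp(th)
        # with (cx, cy) the cell index and (pw, ph) the anchor size in grid units.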
        # pred_boxes_end_time = time.time()
        # print('pred_boxes cost time:', pred_boxes_end_time - pred_boxes_start_time)

        # Get target values
        build_targets_start_time = time.time()
        # coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, tcoord, tconf, tcls = self.build_targets(pred_boxes, target, nH, nW)
        coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, tcoord, tconf, tcls = self.build_targets(pred_boxes, bbox, label, nH, nW)
        build_targets_end_time = time.time()
        print('build_targets cost time:', build_targets_end_time - build_targets_start_time)

        # coord
        coord_mask = coord_mask.expand_as(tcoord)[:, :, :2]  # mask is identical across all 4 coords; two channels broadcast over each xy/wh pair
        coord_center, tcoord_center = coord[:,:,:2], tcoord[:,:,:2]
        coord_wh, tcoord_wh = coord[:,:,2:], tcoord[:,:,2:]
        if nC > 1:
            tcls = tcls[cls_mask].view(-1).long()

            cls_mask = cls_mask.view(-1, 1).repeat(1, nC)
            cls = cls[cls_mask].view(-1, nC)

        # criteria
        self.bce = self.bce.to(device)
        self.mse = self.mse.to(device)
        self.smooth_l1 = self.smooth_l1.to(device)
        self.ce = self.ce.to(device)

        bce = self.bce
        mse = self.mse
        smooth_l1 = self.smooth_l1
        ce = self.ce

        # Compute losses
        loss_coord_center = 2.0 * 1.0 * self.coord_scale * (coord_mask*bce(coord_center, tcoord_center)).sum()
        loss_coord_wh = 2.0 * 1.5 * self.coord_scale * (coord_mask*smooth_l1(coord_wh, tcoord_wh)).sum()
        self.loss_coord = loss_coord_center + loss_coord_wh

        loss_conf_pos = 1.0 * self.object_scale * (conf_pos_mask * bce(conf, tconf)).sum()
        loss_conf_neg = 1.0 * self.noobject_scale * (conf_neg_mask * bce(conf, tconf)).sum()
        self.loss_conf = loss_conf_pos + loss_conf_neg
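        # conf_pos_mask covers cells matched to a gt box; conf_neg_mask starts
        # at 1 and is zeroed both for matched cells and for cells whose best
        # IoU with any gt exceeds ``thresh``, so near-misses are ignored.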

        if nC > 1 and cls.numel() > 0:
            self.loss_cls = self.class_scale * 1.0 * ce(cls, tcls)
            cls_softmax = F.softmax(cls, 1)
            t_ind = torch.unsqueeze(tcls, 1).expand_as(cls_softmax)      
            class_prob = torch.gather(cls_softmax, 1, t_ind)[:, 0]
        else:
            self.loss_cls = torch.tensor(0.0, device=device)
            class_prob = torch.tensor(0.0, device=device)

        obj_cur = max(self.info['obj_cur'], 1)
        self.info['class'] = class_prob.sum().item() / obj_cur
        self.info['obj'] = (conf_pos_mask*conf).sum().item() / obj_cur
        self.info['no_obj'] = (conf_neg_mask*conf).sum().item() / output.numel()
        self.info['coord_xy'] = (coord_mask * mse(coord_center, tcoord_center)).sum().item() / obj_cur
        self.info['coord_wh'] = (coord_mask* mse(coord_wh, tcoord_wh)).sum().item() / obj_cur
        self.printInfo()

        self.loss_tot = self.loss_coord + self.loss_conf + self.loss_cls
        return self.loss_tot

    # def build_targets(self, pred_boxes, ground_truth, nH, nW):
    def build_targets(self, pred_boxes, bbox, label, nH, nW):
        """ Compare prediction boxes and targets, convert targets to network output tensors """
        # return self.__build_targets_brambox(pred_boxes, ground_truth, nH, nW)
        return self.__build_targets_brambox(pred_boxes, bbox, label, nH, nW)

    # def __build_targets_brambox(self, pred_boxes, ground_truth, nH, nW):
    def __build_targets_brambox(self, pred_boxes, bbox, label, nH, nW):
        """ Compare prediction boxes and ground truths, convert ground truths to network output tensors """
        # Parameters
        nB = len(label)
        nA = self.num_anchors
        nAnchors = nA*nH*nW
        nPixels = nH*nW
        device = pred_boxes.device

        # Tensors
        conf_pos_mask = torch.zeros(nB, nA, nH*nW, requires_grad=False, device=device)
        conf_neg_mask = torch.ones(nB, nA, nH*nW, requires_grad=False, device=device)
        coord_mask = torch.zeros(nB, nA, 1, nH*nW, requires_grad=False, device=device)
        cls_mask = torch.zeros(nB, nA, nH*nW, requires_grad=False, dtype=torch.bool, device=device)
        tcoord = torch.zeros(nB, nA, 4, nH*nW, requires_grad=False, device=device)
        tconf = torch.zeros(nB, nA, nH*nW, requires_grad=False, device=device)
        tcls = torch.zeros(nB, nA, nH*nW, requires_grad=False, device=device)

        recall50 = 0
        recall75 = 0
        obj_all = 0
        obj_cur = 0
        iou_sum = 0
        for b in range(nB):
            if len(label[b]) == 0:   # No gt for this image
                continue

            obj_all += len(label[b])

            # Build up tensors
            cur_pred_boxes = pred_boxes[b*nAnchors:(b+1)*nAnchors]
            if self.anchor_step == 4:
                anchors = self.anchors.clone()
                anchors[:, :2] = 0
            else:
                anchors = torch.cat([torch.zeros_like(self.anchors), self.anchors], 1)
            # NOTE: assumes ``bbox`` arrives as [x1, y1, x2, y2] in input-image
            # pixels (matching the width/height computation further below);
            # convert to [xc, yc, w, h] in grid units to match ``pred_boxes``.
            gt_xyxy = torch.tensor(bbox[b], dtype=torch.float, device=device)
            gt = torch.zeros(len(bbox[b]), 4, device=device)
            gt[:, 0] = (gt_xyxy[:, 0] + gt_xyxy[:, 2]) / 2 / self.reduction
            gt[:, 1] = (gt_xyxy[:, 1] + gt_xyxy[:, 3]) / 2 / self.reduction
            gt[:, 2] = (gt_xyxy[:, 2] - gt_xyxy[:, 0]) / self.reduction
            gt[:, 3] = (gt_xyxy[:, 3] - gt_xyxy[:, 1]) / self.reduction

            # Set confidence mask of matching detections to 0
            iou_gt_pred = bbox_ious(gt, cur_pred_boxes)
            mask = (iou_gt_pred > self.thresh).sum(0) >= 1
            conf_neg_mask[b][mask.view_as(conf_neg_mask[b])] = 0

            # Find best anchor for each gt
            gt_wh = gt.clone()
            gt_wh[:, :2] = 0
            iou_gt_anchors = bbox_ious(gt_wh, anchors)
            _, best_anchors = iou_gt_anchors.max(1)
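            # Darknet-style shape matching: centers are zeroed so gts and
            # anchors are compared by wh-IoU alone; a gt whose best anchor
            # belongs to a different head is skipped in the loop below.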

            # Set masks and target values for each gt
            # time consuming
            # for i, anno in enumerate(ground_truth[b]):
            for obj_id, bbox_b in enumerate(bbox[b]):
                gi = min(nW-1, max(0, int(gt[obj_id, 0])))
                gj = min(nH-1, max(0, int(gt[obj_id, 1])))
                cur_n = best_anchors[obj_id].item()
                # print('cur_n:', cur_n)
                if cur_n in self.anchors_mask:
                    best_n = self.anchors_mask.index(cur_n)
                else:
                    continue

                iou = iou_gt_pred[obj_id][best_n*nPixels+gj*nW+gi]
                # debug information
                obj_cur += 1
                recall50 += (iou > 0.5).item()
                recall75 += (iou > 0.75).item()
                iou_sum += iou.item()

                # if anno.ignore:
                #     conf_pos_mask[b][best_n][gj*nW+gi] = 0
                #     conf_neg_mask[b][best_n][gj*nW+gi] = 0
                # else:
                # coord_mask[b][best_n][0][gj*nW+gi] = 2 - anno.width*anno.height/(nW*nH*self.reduction*self.reduction)
                bbox_width = bbox_b[2]-bbox_b[0]
                bbox_height = bbox_b[3]-bbox_b[1]
                coord_mask[b][best_n][0][gj*nW+gi] = 2 - bbox_width*bbox_height/(nW*nH*self.reduction*self.reduction)
                # coord_mask[b][best_n][0][gj*nW+gi] = 2 - anno.width*anno.height/(nW*nH*self.reduction*self.reduction)
                cls_mask[b][best_n][gj*nW+gi] = 1
                conf_pos_mask[b][best_n][gj*nW+gi] = 1
                conf_neg_mask[b][best_n][gj*nW+gi] = 0
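                # Regression targets are the inverse of the decode in forward:
                # (tx, ty) is the offset of the gt center inside cell (gi, gj),
                # (tw, th) the log ratio of gt size to the matched anchor.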
                tcoord[b][best_n][0][gj*nW+gi] = gt[obj_id, 0] - gi
                tcoord[b][best_n][1][gj*nW+gi] = gt[obj_id, 1] - gj
                tcoord[b][best_n][2][gj*nW+gi] = math.log(gt[obj_id, 2]/self.anchors[cur_n, 0])
                tcoord[b][best_n][3][gj*nW+gi] = math.log(gt[obj_id, 3]/self.anchors[cur_n, 1])
                tconf[b][best_n][gj*nW+gi] = 1
                # tcls[b][best_n][gj*nW+gi] = anno.class_id
                tcls[b][best_n][gj*nW+gi] = label[b][obj_id]
        # loss information
        self.info['obj_cur'] = obj_cur
        self.info['obj_all'] = obj_all
        if obj_cur == 0: 
            obj_cur = 1
        self.info['avg_iou'] = iou_sum / obj_cur 
        self.info['recall50'] = recall50 / obj_cur
        self.info['recall75'] = recall75 / obj_cur

        return coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, tcoord, tconf, tcls

    def printInfo(self):
        info = self.info
        info_str = 'AVG IOU %.4f, Class %.4f, Obj %.4f, No obj %.4f, ' \
                '.5R %.4f, .75R %.4f, Cur obj %3d, All obj %3d, Coord xy %.4f, Coord wh %.4f' % \
                (info['avg_iou'], info['class'], info['obj'], info['no_obj'],
                        info['recall50'], info['recall75'], info['obj_cur'], info['obj_all'],
                        info['coord_xy'], info['coord_wh'])
        log.info('Head %d\n%s' % (self.head_idx, info_str))

        # reset
        self.info = {'avg_iou': 0, 'class': 0, 'obj': 0, 'no_obj': 0, 
                'recall50': 0, 'recall75': 0, 'obj_cur': 0, 'obj_all': 0,
                'coord_xy': 0, 'coord_wh': 0}

if __name__ == '__main__':
    # Expected input format:
    #   bbox:  B * [N_i * 4], each box [x1, y1, x2, y2] in input-image pixels
    #   label: B * [N_i], integer class ids
    # e.g. bbox = [[[0, 1, 2, 3], [4, 5, 6, 7]], ...], label = [[0, 0], ...]
    anchors = [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45), (59, 119), (116, 90), (156, 198), (373, 326)]
    anchors_mask = [(6, 7, 8), (3, 4, 5), (0, 1, 2)]
    W = 1024
    H = 720
    B = 64
    C = 3
    bbox = []
    label = []
    for b in range(B):
        N = random.randint(0, B)

        bbox_b = []
        for n in range(N):
            # guarantee x1 < x2 and y1 < y2 so widths/heights stay positive
            # (a zero width/height would break the log in build_targets)
            x1 = random.randint(0, W - 2)
            y1 = random.randint(0, H - 2)
            x2 = random.randint(x1 + 1, W - 1)
            y2 = random.randint(y1 + 1, H - 1)
            bbox_b.append([x1, y1, x2, y2])
        bbox.append(bbox_b)

        label_b = []
        for n in range(N):
            cls = random.randint(0, 1)
            label_b.append(cls)
        label.append(label_b)
    print(bbox)
    print(label)

    # Build one loss head per stride so each anchors_mask is paired with a
    # matching reduction (mask (6,7,8) on the coarsest grid), instead of
    # reusing a single head built for reduction=32 at every scale.
    reduces = [64, 32, 16]

    for head_idx, reduce in enumerate(reduces):
        loss = YoloLossAdvance(num_classes=C, anchors=anchors,
                               anchors_mask=anchors_mask[head_idx],
                               reduction=reduce, head_idx=head_idx)
        W1 = W // reduce
        H1 = H // reduce
        output = torch.zeros((len(bbox), len(anchors_mask[head_idx]) * (4 + 1 + C), H1, W1))

        loss_start_time = time.time()
        loss(output, bbox, label)
        loss_end_time = time.time()
        print('reduce {} loss cost time: {:.4f}s'.format(reduce, loss_end_time - loss_start_time))