chyueG / HAMBox-high-anchors-mining

HAMBox: Delving into Online High-quality Anchors Mining for Detecting Outer Faces

More discussion on this project #2

Open zehuichen123 opened 3 years ago

zehuichen123 commented 3 years ago

Hi, I have been working on Widerface recently, and my codebase is based on RetinaFace_Pytorch. I am quite interested in HAMBox; could you share your code so that we can both give it a try? Currently, this repo only contains the code for the top-k loss and the focal loss. Thanks a lot!

zehuichen123 commented 3 years ago

Hi, @chyueG! Thanks for your code. Based on the code you provided and multibox_loss.py, I've tried to re-implement it in a much simpler style. However, the performance is not satisfying: only 86.0 on the hard set with a vanilla Res50 backbone (w/o OAM it achieves 91.2 on hard with multi-scale testing). Are you still working on this project, so we could recheck the code together? Here is my implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.box_utils import *

def focal_loss(input, target, alpha=0.25, gamma=2):
    min_thres = 1e-8
    pos_res = -torch.pow(1 - input, gamma) * target * torch.log(torch.clamp(input, min=min_thres))
    neg_res = -torch.pow(input, gamma) * (1 - target) * torch.log(torch.clamp(1 - input, min=min_thres))
    return torch.sum(pos_res) * alpha + torch.sum(neg_res) * (1 - alpha)

def RA_focal_loss(input, target, ious, alpha=0.25, gamma=2):
    min_thres = 1e-8
    pos_res = -torch.pow(1 - input, gamma) * target * torch.log(torch.clamp(input, min=min_thres))
    # neg_res = -torch.pow(input, gamma) * (1 - target) * torch.log(torch.clamp(1 - input, min=min_thres))
    # return torch.sum(pos_res) * alpha + torch.sum(neg_res) * (1 - alpha)
    return torch.sum(pos_res * ious) * alpha

class HAMLoss(nn.Module):
    def __init__(self, img_size=640, OAM=False, rank=0):
        super(HAMLoss, self).__init__()
        # self.num_classes = num_cls
        # self.variance = variance

        self.variance = [0.1, 0.2]
        self.K = 3
        self.T1 = 0.35
        self.T2 = 0.5
        self.T = 0.8
        self.alpha = 0.25
        self.gamma = 2
        self.img_size = img_size
        self.rank = rank
        self.OAM = OAM

    def forward(self, predictions, priors, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)
            targets (list): per-image ground-truth tensors of shape
                [num_objs, 15]: box (4), landmarks (10), label (last column).
        """
        if len(predictions) == 3:
            has_landm = True
            loc_data, conf_data, landm_data = predictions
        else:
            has_landm = False
            loc_data, conf_data = predictions   # no landmark predictions, so the landmark loss is skipped
        batch_size = loc_data.size(0)
        prior_size = priors.size(0)

        cls_labels_list = []; reg_labels_list = []
        if has_landm:
            valid_labels_list = []; landm_labels_list = []
        if self.OAM:
            select_idx_list = []; compensate_labels_list = []; compensate_ious_list = []
        for idx in range(batch_size):
            truths = targets[idx][:, :4].data.float()   # x1y1x2y2
            labels = targets[idx][:, -1].data.float()
            landms = targets[idx][:, 4:14].data.float()

            origin_truth = truths * self.img_size
            valid_idx = torch.min(origin_truth[:, 2] - origin_truth[:, 0], origin_truth[:, 3] - origin_truth[:, 1]) >= 10   # drop faces smaller than 10 px on the shorter side
            truths = truths[valid_idx]; labels = labels[valid_idx]; landms = landms[valid_idx]
            if truths.shape[0] == 0:
                cls_labels = torch.zeros((prior_size), dtype=torch.bool).cuda(self.rank)
                reg_labels = torch.zeros((prior_size, 4)).cuda(self.rank)
                if has_landm:
                    landm_labels = torch.zeros((prior_size, 10)).cuda(self.rank)
                    valid_labels = torch.zeros((prior_size,)).cuda(self.rank)   # only marks whether the landmarks are valid
                if self.OAM:
                    select_idx = torch.ones((prior_size), dtype=torch.bool).cuda(self.rank)   # by default, if everything is background, keep all anchors
                    compensate_idx = torch.zeros((prior_size), dtype=torch.bool).cuda(self.rank)
                    compensate_ious = torch.zeros((prior_size)).cuda(self.rank)
            else:
                iou = jaccard(truths, point_form(priors))
                max_overlaps, argmax_overlaps = iou.max(0)
                cls_labels = max_overlaps >= self.T1
                # cls_labels[gt_argmax_overlaps] = 1    # allow_low_quality_matching
                match_gts = truths[argmax_overlaps]
                reg_labels = encode(match_gts, priors, self.variance)
                if has_landm:
                    matches_landm = landms[argmax_overlaps]
                    landm_labels = encode_landm(matches_landm, priors, self.variance)
                    valid_labels = labels[argmax_overlaps] == 1
                if self.OAM:
                    loc_pred = loc_data[idx]
                    loc_pred = loc_pred.detach()        # block gradient from IoU computation
                    loc_pred_bbox = decode(loc_pred, priors, self.variance)
                    regress_iou = jaccard(truths, loc_pred_bbox)        # GT_size, anchor_size
                    reg_max_overlaps, _ = regress_iou.max(0)
                    low_quality_anchors = reg_max_overlaps < self.T1
                    neg_idx = low_quality_anchors * ~cls_labels
                    pos_idx = cls_labels == 1
                    select_idx = neg_idx | pos_idx
                    # find compensated anchors
                    compensate_idx = cls_labels == -2    # initialize all-False labels (faster than torch.zeros().cuda())
                    compensate_ious = cls_labels * 0.0
                    num_gt = iou.shape[0]
                    for ii in range(num_gt):
                        iou_per_gt = iou[ii]
                        num_pos_anchor = (iou_per_gt > self.T1).sum()
                        if num_pos_anchor >= self.K:
                            continue
                        else:
                            compensate_num = int(self.K - num_pos_anchor)
                            reg_iou_per_gt = regress_iou[ii]
                            reg_topk_iou_per_gt, reg_argtopk_iou_per_gt = torch.topk(reg_iou_per_gt, k=self.K)
                            for k_idx in range(self.K):
                                anchor_idx = reg_argtopk_iou_per_gt[k_idx]; anchor_iou = reg_topk_iou_per_gt[k_idx]
                                if anchor_iou < self.T:
                                    break
                                if anchor_iou >= self.T and cls_labels[anchor_idx] != 1:
                                    # print(anchor_idx, compensate_ious[anchor_idx], anchor_iou)
                                    compensate_idx[anchor_idx] = 1
                                    compensate_ious[anchor_idx] = max(compensate_ious[anchor_idx], anchor_iou)
                                    compensate_num -= 1
                                if compensate_num == 0:
                                    break
            cls_labels_list.append(cls_labels)
            reg_labels_list.append(reg_labels)
            if has_landm:
                landm_labels_list.append(landm_labels)
                valid_labels_list.append(valid_labels)
            if self.OAM:
                select_idx_list.append(select_idx)
                compensate_labels_list.append(compensate_idx)
                compensate_ious_list.append(compensate_ious)
        # generate labels
        cls_labels = torch.stack(cls_labels_list, dim=0)
        reg_labels = torch.stack(reg_labels_list, dim=0)
        if has_landm:
            landm_labels = torch.stack(landm_labels_list, dim=0)
            valid_labels = torch.stack(valid_labels_list, dim=0)
        if self.OAM:
            select_idxs = torch.stack(select_idx_list, dim=0)
            compensate_labels = torch.stack(compensate_labels_list, dim=0)
            compensate_ious = torch.stack(compensate_ious_list, dim=0)

        # cls
        conf_data = conf_data.view(-1, 1); cls_labels = cls_labels.view(-1, 1)
        if self.OAM:
            select_idxs = select_idxs.view(-1, 1)
            cls_loss = focal_loss(torch.sigmoid(conf_data[select_idxs]), cls_labels[select_idxs].float(), alpha=0.25, gamma=2.0)
        else:
            cls_loss = focal_loss(torch.sigmoid(conf_data), cls_labels.float(), alpha=0.25, gamma=2.0)
        # regress
        loc_data = loc_data.view(-1, 4); reg_labels = reg_labels.view(-1, 4)
        reg_loss = F.smooth_l1_loss(loc_data[cls_labels.view(-1,)], reg_labels[cls_labels.view(-1,)], reduction='none')   # shape(bs * anchor, 4)
        # landmark
        if has_landm:
            landm_valid_labels = cls_labels * valid_labels.view(-1,1)
            landm_data = landm_data.view(-1, 10); landm_labels = landm_labels.view(-1, 10)
            landm_loss = F.smooth_l1_loss(landm_data, landm_labels, reduction='none')
            landm_loss = landm_loss * landm_valid_labels
        # normalize loss

        # NOTE OAM different from normal matching
        num_pos_anchors = cls_labels.sum().item()
        if has_landm:
            num_pos_landms = landm_valid_labels.sum().item()
        # NOTE for compatibility with OAM
        compensate_cls_loss = 0
        compensate_reg_loss = 0
        if self.OAM:
            compensate_labels = compensate_labels.view(-1, 1)
            compensate_ious = compensate_ious.view(-1, 1)
            compensate_conf_data = conf_data[compensate_labels]
            compensate_ious = compensate_ious[compensate_labels]
            num_compensate_anchor = compensate_conf_data.shape[0]
            if num_compensate_anchor != 0:
                compensate_cls_loss = RA_focal_loss(torch.sigmoid(compensate_conf_data), 1, compensate_ious)
                compensate_cls_loss = compensate_cls_loss.sum() / max(num_compensate_anchor, 1)
                compensate_reg_loss = F.smooth_l1_loss(loc_data[compensate_labels.view(-1,)], reg_labels[compensate_labels.view(-1,)], reduction='none')
                compensate_reg_loss = compensate_reg_loss.sum() / max(num_compensate_anchor, 1)
        # print(num_pos_anchors)
        reg_loss = reg_loss.sum() / max(num_pos_anchors, 1) + compensate_reg_loss
        cls_loss = cls_loss.sum() / max(num_pos_anchors, 1) + compensate_cls_loss
        loss_list = [reg_loss, cls_loss]
        if has_landm:
            landm_loss = landm_loss.sum() / max(num_pos_landms, 1)
            loss_list.append(landm_loss)
        return loss_list
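
For completeness, here is roughly how I call the loss during a training step. This is only a shape sketch with dummy tensors (a GPU is assumed since the loss calls .cuda(rank) internally), not code from the repo:

import torch

num_priors, batch_size, num_faces = 16800, 2, 5
loss_fn = HAMLoss(img_size=640, OAM=True, rank=0)

loc = torch.randn(batch_size, num_priors, 4).cuda()         # regression head output
conf = torch.randn(batch_size, num_priors, 1).cuda()        # classification head output (logits)
priors = torch.rand(num_priors, 4).clamp(min=0.05).cuda()   # cx, cy, w, h, normalized to [0, 1]

# per-image targets: [x1, y1, x2, y2, 10 landmark coords, label], normalized to [0, 1]
x1y1 = torch.rand(batch_size, num_faces, 2) * 0.5
wh = torch.rand(batch_size, num_faces, 2) * 0.4 + 0.05      # keeps every face larger than 10 px at 640
landms = torch.rand(batch_size, num_faces, 10)
labels = torch.ones(batch_size, num_faces, 1)
targets = [torch.cat([x1y1[i], x1y1[i] + wh[i], landms[i], labels[i]], dim=1).cuda()
           for i in range(batch_size)]

reg_loss, cls_loss = loss_fn((loc, conf), priors, targets)  # two predictions, so no landmark loss
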
zehuichen123 commented 3 years ago

https://github.com/chyueG/HAMBox-high-anchors-mining/blob/6e22ce5592689ebcafe89211d7b95e16896882a8/ham_loss.py#L88 One question here: what if there are already 3 positive anchors matched to the same GT, but the best-regressed anchor is not among those 3 positives? It looks like this code may generate 3+K positive anchors in total for one GT.
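
In case it helps the discussion, below is a minimal self-contained sketch of how the compensation step could be capped so that each GT ends up with at most K positives in total. The helper cap_compensated_anchors and its arguments are hypothetical names following my re-implementation above, not the code in ham_loss.py:

import torch

def cap_compensated_anchors(anchor_iou, regressed_iou, cls_pos, K=3, T1=0.35, T=0.8):
    # anchor_iou:    (num_gt, num_anchors) IoU between GTs and the prior anchors
    # regressed_iou: (num_gt, num_anchors) IoU between GTs and the regressed boxes
    # cls_pos:       (num_anchors,) bool mask of anchors already positive via the T1 rule
    num_gt, num_anchors = anchor_iou.shape
    comp_mask = torch.zeros(num_anchors, dtype=torch.bool)
    comp_iou = torch.zeros(num_anchors)
    for g in range(num_gt):
        budget = K - int((anchor_iou[g] > T1).sum())
        if budget <= 0:
            continue                                          # this GT already has >= K positives
        topk_iou, topk_idx = torch.topk(regressed_iou[g], k=min(K, num_anchors))
        for iou_val, idx in zip(topk_iou, topk_idx):
            if iou_val < T or budget == 0:
                break
            if not cls_pos[idx] and not comp_mask[idx]:       # skip anchors that are already positive
                comp_mask[idx] = True
                comp_iou[idx] = iou_val
                budget -= 1
    return comp_mask, comp_iou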

zehuichen123 commented 3 years ago

https://github.com/chyueG/HAMBox-high-anchors-mining/blob/6e22ce5592689ebcafe89211d7b95e16896882a8/ham_loss.py#L77 I think you should detach loc_pred to avoid gradient propagation from the IoU computation, which is mentioned in your README.
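
Concretely, something like the following sketch (decode and jaccard are the helpers from utils.box_utils used in my re-implementation above; the variable names are mine, not those in ham_loss.py):

loc_pred = loc_data[idx].detach()                        # stop gradients before decoding
loc_pred_bbox = decode(loc_pred, priors, self.variance)
regress_iou = jaccard(truths, loc_pred_bbox)             # the OAM IoU now carries no gradient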

chyueG commented 3 years ago

I referred to the DIoU code.