amdegroot / ssd.pytorch

A PyTorch Implementation of Single Shot MultiBox Detector
MIT License
5.12k stars 1.74k forks

I don't understand the process of calculating IoU in the function "def nms" described in "box_utils.py". #562

Open ukpana opened 3 years ago

ukpana commented 3 years ago

I don't understand the process of calculating IoU in the function "def nms" described in "box_utils.py".

I understand that IoU = "area of the overlapping part (intersection)" / "area of the union of both boxes".

I also understand that the basic principle of SSD is that the SSD model outputs only offset and confidence information, and that the offsets are (▵cx, ▵cy, ▵w, ▵h). I also know that the bounding boxes passed into the function have already been decoded using the variances (0.1, 0.2): decoded_boxes = decode(loc_data[i], prior_data, self.variance)
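For reference, the decoding step mentioned above can be sketched in plain Python for a single box. This is only an illustration under assumptions: it assumes the usual SSD convention (priors in center form (cx, cy, w, h), decoded boxes in corner form (x1, y1, x2, y2)), and the helper name decode_one and the sample numbers are made up here. It is not the repository's tensor implementation.

```python
import math


def decode_one(loc, prior, variances=(0.1, 0.2)):
    """Decode one predicted offset against one prior box (hypothetical helper).

    loc:   (dcx, dcy, dw, dh) offsets predicted by the network
    prior: (cx, cy, w, h) prior box in center form
    Returns the box in corner form (x1, y1, x2, y2).
    """
    dcx, dcy, dw, dh = loc
    pcx, pcy, pw, ph = prior
    # Center offsets are scaled by variances[0] and by the prior size.
    cx = pcx + dcx * variances[0] * pw
    cy = pcy + dcy * variances[0] * ph
    # Size offsets are scaled by variances[1] and applied in log space.
    w = pw * math.exp(dw * variances[1])
    h = ph * math.exp(dh * variances[1])
    # Convert center form to corner form.
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


box = decode_one(loc=(1.0, -1.0, 0.0, 0.0), prior=(0.5, 0.5, 0.2, 0.2))
print([round(v, 4) for v in box])  # [0.42, 0.38, 0.62, 0.58]
```

The point for the question is that nms receives boxes already in (x1, y1, x2, y2) form, so the offsets and variances play no further role inside it.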

(1) In the following four lines of "def nms(boxes, scores, overlap=0.5, top_k=200):" in "box_utils.py", I don't understand why the x1, y1, x2, and y2 values of the remaining boxes are clamped, using the coordinates of the box with the highest confidence as the minimum and maximum values.

xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])

I think it's related to how NMS works, but could anyone please explain it to me?
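One way to read those four clamp lines is that they compute the corners of the intersection rectangle between the highest-scoring box and each remaining box. The plain-Python sketch below illustrates this for one pair of boxes at a time; the function name intersection_box and the sample boxes are hypothetical. Here max(...) plays the role of torch.clamp(..., min=...) and min(...) the role of torch.clamp(..., max=...).

```python
def intersection_box(top, other):
    """Intersection of two (x1, y1, x2, y2) boxes (hypothetical helper).

    top:   the box with the highest confidence (x1[i], y1[i], x2[i], y2[i])
    other: one of the remaining boxes (an element of xx1, yy1, xx2, yy2)
    """
    ix1 = max(other[0], top[0])  # like torch.clamp(xx1, min=x1[i])
    iy1 = max(other[1], top[1])  # like torch.clamp(yy1, min=y1[i])
    ix2 = min(other[2], top[2])  # like torch.clamp(xx2, max=x2[i])
    iy2 = min(other[3], top[3])  # like torch.clamp(yy2, max=y2[i])
    # If the boxes do not overlap, the width or height goes negative,
    # so it is clamped to zero and the intersection area becomes zero.
    w = max(ix2 - ix1, 0.0)
    h = max(iy2 - iy1, 0.0)
    return (ix1, iy1, ix2, iy2), w * h


top = (0.0, 0.0, 4.0, 4.0)
overlapping = (2.0, 2.0, 6.0, 6.0)
disjoint = (5.0, 5.0, 7.0, 7.0)

print(intersection_box(top, overlapping))  # ((2.0, 2.0, 4.0, 4.0), 4.0)
print(intersection_box(top, disjoint))     # ((5.0, 5.0, 4.0, 4.0), 0.0)
```

So the clamped values are not modified copies of the original boxes for their own sake; they are only intermediate values used to get the overlap width and height.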

(2) Also, area[i] in the union formula is the area of the box with the highest confidence. "rem_areas" holds the areas of the remaining boxes, taken from "area" in the same order as the score-sorted indices. "inter" is the intersection area computed from the xx1, xx2, yy1, and yy2 values from (1). Why do we subtract "inter" from rem_areas?

union = (rem_areas - inter) + area[i]
IoU = inter / union
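The subtraction can be checked with a small worked example. If the two areas were simply added, the overlapping region would be counted twice, so it is subtracted once: union = area(a) + area(b) - inter, which is the same as (rem_areas - inter) + area[i]. The plain-Python sketch below uses hypothetical helper names and made-up boxes to show the arithmetic.

```python
def box_area(b):
    """Area of an (x1, y1, x2, y2) box."""
    return (b[2] - b[0]) * (b[3] - b[1])


def inter_area(a, b):
    """Intersection area of two (x1, y1, x2, y2) boxes."""
    w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
    h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
    return w * h


top = (0.0, 0.0, 4.0, 4.0)    # highest-confidence box, plays the role of area[i]
other = (2.0, 2.0, 6.0, 6.0)  # one remaining box, plays the role of rem_areas

inter = inter_area(top, other)                       # 2 x 2 overlap
union = (box_area(other) - inter) + box_area(top)    # same form as in nms
iou = inter / union

print(inter, union, round(iou, 4))  # 4.0 28.0 0.1429
```

Each box has area 16, and the 2 x 2 overlap is shared by both, so the covered region is 16 + 16 - 4 = 28 rather than 32.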

Help me.

import torch


def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class pred scores for the img, Shape: [num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        The indices of the kept boxes with respect to num_priors, and the
        number of kept boxes (count).
    """

    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0:
        return keep, 0  # same (keep, count) shape as the normal return path
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    # I = I[v >= 0.01]
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    # keep = torch.Tensor()
    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        # keep.append(i)
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # clamp the remaining boxes to the kept box i to get the intersection
        # corners: element-wise max for (x1, y1), element-wise min for (x2, y2)
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # check sizes of xx1 and xx2.. after each iteration
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w*h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter/union  # store result in iou
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
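To see how the clamping and the union formula fit into the overall loop, here is a minimal plain-Python sketch of the same greedy procedure on lists instead of tensors. It is an illustration under assumptions, not the repository's code: the names greedy_nms and iou and the sample boxes are hypothetical, and it returns only the kept indices rather than (keep, count).

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
    h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
    inter = w * h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def greedy_nms(boxes, scores, overlap=0.5, top_k=200):
    """Greedy NMS on Python lists; returns indices of the kept boxes."""
    # Sort indices by ascending score and keep only the top_k highest.
    order = sorted(range(len(scores)), key=lambda j: scores[j])[-top_k:]
    keep = []
    while order:
        i = order.pop()  # index with the current highest score
        keep.append(i)
        # Drop every remaining box that overlaps box i too much.
        order = [j for j in order if iou(boxes[i], boxes[j]) <= overlap]
    return keep


boxes = [
    (0.0, 0.0, 4.0, 4.0),      # highest score, kept
    (0.5, 0.5, 4.5, 4.5),      # IoU with box 0 is about 0.62, suppressed
    (10.0, 10.0, 12.0, 12.0),  # no overlap with box 0, kept
]
scores = [0.9, 0.8, 0.7]

print(greedy_nms(boxes, scores))  # [0, 2]
```

Each pass keeps the best remaining box and removes its near-duplicates, which is why only the IoU against the current highest-scoring box is needed in every iteration.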