lilanxiao / Rotated_IoU

Differentiable IoU of rotated bounding boxes using Pytorch
MIT License
412 stars, 62 forks

Minimum Bounding Box implementation Rotating Calipers #16

Closed. JonathanCMitchell closed this issue 2 years ago

JonathanCMitchell commented 3 years ago

The minimum bounding box method has three types (aligned, pca, and smallest). Is the "smallest" method using the "Rotating Calipers" algorithm? I was looking for a torch implementation of this and if this is it then thanks!

lilanxiao commented 3 years ago

I've tried to implement rotating calipers with Pytorch, but I didn't make it, lol. It's non-trivial to make this algo work with the "tensor-style computing". Coding directly with CUDA is probably necessary. As a quick and dirty solution, I simply use the brute force search instead.

It is a known geometric result that at least one side of the smallest enclosing rectangle must be collinear with an edge of the convex polygon (the convex hull of the points). With 8 vertices (two boxes), there are 8x7/2 = 28 possible vertex pairs that could define this edge. I just check all the possibilities and keep the smallest box. Not an efficient method, but it works.
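The brute-force search described above can be sketched roughly as follows. This is a minimal illustration, not the code from min_enclosing_box.py: the function name smallest_enclosing_box_area is hypothetical, it handles a single unbatched set of points, and it returns only the area.

```python
import math
import torch

def smallest_enclosing_box_area(points: torch.Tensor) -> torch.Tensor:
    """points: (N, 2) tensor. Returns the area of the smallest enclosing
    rectangle, found by trying every vertex pair as a candidate edge."""
    n = points.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)      # all N*(N-1)/2 pairs
    d = points[j] - points[i]                      # (P, 2) edge directions
    sq_len = (d * d).sum(dim=1, keepdim=True)
    valid = sq_len.squeeze(1) > 1e-12              # skip coincident vertices
    u = d / sq_len.clamp_min(1e-16).sqrt()         # unit vector along the edge
    v = torch.stack([-u[:, 1], u[:, 0]], dim=1)    # unit normal to the edge
    proj_u = u @ points.T                          # (P, N) extents along edge
    proj_v = v @ points.T                          # (P, N) extents across edge
    width = proj_u.max(dim=1).values - proj_u.min(dim=1).values
    height = proj_v.max(dim=1).values - proj_v.min(dim=1).values
    area = width * height
    area = torch.where(valid, area, torch.full_like(area, float("inf")))
    return area.min()

# Usage: a 2x1 rectangle rotated by 30 degrees, plus some interior points.
theta = math.radians(30.0)
rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                    [math.sin(theta), math.cos(theta)]])
rect = torch.tensor([[0.0, 0.0], [2.0, 0.0], [2.0, 1.0], [0.0, 1.0]])
inner = torch.tensor([[0.5, 0.5], [1.0, 0.5], [1.5, 0.5], [1.0, 0.25]])
pts = (torch.cat([rect, inner]) @ rot.T).requires_grad_()
area = smallest_enclosing_box_area(pts)
area.backward()                                    # gradient reaches pts
num_pairs = torch.triu_indices(8, 8, offset=1).shape[1]
```

Because only tensor operations are used, the result stays connected to the autograd graph. With 8 points the pair count comes out to the 28 candidates mentioned above.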

JonathanCMitchell commented 3 years ago

I've seen an implementation of rotating calipers in NumPy here: https://chadrick-kwag.net/python-implementation-of-rotating-caliper-algorithm/ so maybe it won't be too hard to implement.

By the way, if I use this NumPy implementation off the shelf, it won't lead to convergence, right? Because if it isn't in torch, I can't backpropagate through it? I am not sure we need to backprop through the convex hull, though.

lilanxiao commented 3 years ago

The NumPy implementation looks good. It should be a good starting point for a PyTorch implementation.

Yeah, the chain of gradient would be broken. You would probably get an error from Pytorch when you do the back-propagation.
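A small sketch of what "broken chain" looks like in practice, assuming a toy tensor x rather than the real hull points. Routing a value through NumPy detaches it from the graph, and the backward call then fails:

```python
import numpy as np
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Calling .numpy() on a tensor that requires grad is rejected outright.
try:
    x.numpy()
    numpy_call_failed = False
except RuntimeError:
    numpy_call_failed = True

# Detaching first works, but the result no longer knows about x.
y = torch.from_numpy(np.square(x.detach().numpy()))

try:
    y.sum().backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
```

After this runs, y.requires_grad is False and x.grad is still None, which is why a NumPy step in the middle of a loss cannot train the layers before it.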

JonathanCMitchell commented 3 years ago

@lilanxiao Ok so I implemented it using PyTorch.

import math

import numpy as np
import torch as t

def min_bounding_rect_torch(hull_points_2d):
    """
    hull_points_2d: array of hull points. each element should have [x,y] format
    """
    # Compute edges (x2-x1,y2-y1)
    edges = t.zeros( (len(hull_points_2d)-1,2) ).cuda() # empty 2 column array
    for i in range( len(edges) ):
        edge_x = hull_points_2d[i+1,0] - hull_points_2d[i,0]
        edge_y = hull_points_2d[i+1,1] - hull_points_2d[i,1]
        edges[i] = t.FloatTensor([edge_x,edge_y])

    # Calculate edge angles: atan2(y, x)
    edge_angles = t.zeros( (len(edges)) ) # empty 1 column array
    for i in range( len(edge_angles) ):
        edge_angles[i] = t.atan2( edges[i,1], edges[i,0] )

    # Check for angles in 1st quadrant
    for i in range( len(edge_angles) ):
        edge_angles[i] = t.abs( edge_angles[i] % (np.pi/2) ) # want strictly positive answers

    edge_angles = t.unique(edge_angles)

    min_bbox = t.FloatTensor([0, float("inf"), 0, 0, 0, 0, 0, 0]).cuda() # rot_angle, area, width, height, min_x, max_x, min_y, max_y
    for i in range(len(edge_angles)):
        R = t.FloatTensor([ [ t.cos(edge_angles[i]), t.cos(edge_angles[i]-(math.pi/2)) ], [ t.cos(edge_angles[i]+(math.pi/2)), t.cos(edge_angles[i]) ] ]).cuda()

        rot_points = t.mm(R, hull_points_2d.t()) # [check]

        # Find min/max x,y points
        # Workaround for nanmax nanmin
        # In rot_points make two sets. One for maxes and one for mins 
        # Retain grad on clone() https://discuss.pytorch.org/t/how-does-clone-interact-with-backpropagation/8247/6
        INTMAX = float('inf')
        INTMIN = -1 * float('inf')
        rot_points_for_max = rot_points.clone()
        rot_points_for_max.retain_grad()
        rot_points_for_max[rot_points_for_max != rot_points_for_max] = INTMIN

        rot_points_for_min = rot_points.clone()
        rot_points_for_min.retain_grad()
        rot_points_for_min[rot_points_for_min != rot_points_for_min] = INTMAX

        min_x = t.min(rot_points_for_min[0], dim=0)[0]
        max_x = t.max(rot_points_for_max[0], dim=0)[0]

        min_y = t.min(rot_points_for_min[1], dim=0)[0]
        max_y = t.max(rot_points_for_max[1], dim=0)[0]

        # Calculate height/width/area of this bounding rectangle
        width = max_x - min_x
        height = max_y - min_y
        area = width*height
        # This is where I lose the grad!
        # Store the smallest rect found first (a simple convex hull might have 2 answers with same area)
        if (area < min_bbox[1]):
            min_bbox = t.cuda.FloatTensor([edge_angles[i], area, width, height, min_x, max_x, min_y, max_y ])

    # Re-create rotation matrix
    angle = min_bbox[0]
    R = t.FloatTensor([ [ t.cos(angle), t.cos(angle-(math.pi/2)) ], [ t.cos(angle+(math.pi/2)), t.cos(angle) ] ]).cuda()

    # Proj points
    proj_points = t.mm(R, hull_points_2d.t())

    # min/max x,y points are against baseline
    min_x = min_bbox[4]
    max_x = min_bbox[5]
    min_y = min_bbox[6]
    max_y = min_bbox[7]

    # Calculate center point and project onto rotated frame
    center_x = (min_x + max_x)/2
    center_y = (min_y + max_y)/2
    center_point = t.mm(t.unsqueeze(t.FloatTensor([center_x, center_y]).cuda(),dim=0), R)[0]

    # Calculate corner points and project onto rotated frame
    corner_points = t.zeros((4,2)).cuda() # empty 2 column array
    corner_points[0] = t.mm(t.unsqueeze(t.cuda.FloatTensor([max_x, min_y]),dim=0), R)[0]
    corner_points[1] = t.mm(t.unsqueeze(t.cuda.FloatTensor([min_x, min_y]),dim=0), R)[0]
    corner_points[2] = t.mm(t.unsqueeze(t.cuda.FloatTensor([min_x, max_y]),dim=0), R)[0]
    corner_points[3] = t.mm(t.unsqueeze(t.cuda.FloatTensor([max_x, max_y]),dim=0), R)[0]

    # Looks like I lose the grad somewhere!
    return [angle, min_bbox[1], min_bbox[2], min_bbox[3], center_point, corner_points]
   # rot_angle, area, width, height, center_point, corner_points

So I have two concerns with my implementation: (1) I have to use this workaround to do the nanmax operation, and (2) the gradient is lost after computing the area, so I am not sure whether this will backprop.

Assume I want to use it like this: pred = model(input), where input is an autograd variable with requires_grad=True. pred is also an autograd variable with requires_grad=True, and gt is the ground-truth bounding box.

hull = convex_hull(pred) # convex hull points
(angle, convex_area, w, h, center, corners) = min_bounding_rect_torch(hull) # hull points are autograd variables w requires_grad=True
# Assume I got the union and iou from somewhere and that they are both autograd variables w/ requires_grad = True
giou_val = (convex_area - union) / (convex_area + 1e-16)
giou_loss += 1. - (iou - (convex_area - union) / (convex_area + 1e-16))
# then eventually I will call:
giou_loss.backward()

The problem is that convex_area, when inspected, does not have requires_grad=True. I lose the requires_grad=True part on the area variable inside the min_bounding_rect_torch() function. It does, however, have the property is_leaf=True.

Just wondering how I can make this thing differentiable or if I even need to make the area part differentiable.

lilanxiao commented 3 years ago

@JonathanCMitchell It's not a good idea to pass values directly to the FloatTensor() constructor, because the constructor copies only the values and creates a new tensor that is disconnected from the computation graph, so no gradient flows back through it. Instead, create the new tensor first and then assign the values into it. Try this small demo:

import torch

def test1():
    a = torch.rand(3, 3).requires_grad_()
    b = torch.rand(3, 3).requires_grad_()
    c = a + b
    # define a new tensor, copy value and gradient
    d = torch.zeros(2)
    d[0] = c[0, 0]
    d[1] = c[2, 2]
    loss = torch.mean(d)
    loss.backward()
    print("gradient:")
    print("a:", a.grad)
    print("b:", b.grad)

def test2():
    a = torch.rand(3, 3).requires_grad_()
    b = torch.rand(3, 3).requires_grad_()
    c = a + b
    # pass element values directly to the constructor, gradient is lost
    d = torch.FloatTensor([c[0,0], c[2, 2]]).requires_grad_()
    loss = torch.mean(d)
    loss.backward()
    print("gradient:")
    print("a:", a.grad)
    print("b:", b.grad)

if __name__ == "__main__":
    test1()
    test2()

You would get the proper gradient with the first test case.

a: tensor([[0.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
b: tensor([[0.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5000]])

But you get None with the second.

a: None
b: None
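
As a side note, torch.stack is another way to pack existing scalars into a new tensor without leaving the graph. A minimal sketch using the same toy tensors as the demo above:

```python
import torch

a = torch.rand(3, 3).requires_grad_()
b = torch.rand(3, 3).requires_grad_()
c = a + b

# torch.stack is a differentiable op, so d stays connected to a and b.
d = torch.stack([c[0, 0], c[2, 2]])
loss = torch.mean(d)
loss.backward()

grad_a = a.grad  # 0.5 at [0, 0] and [2, 2], zero elsewhere
grad_b = b.grad
```

This should give the same gradients as the first test case, without pre-allocating a zeros tensor.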

JonathanCMitchell commented 3 years ago

@lilanxiao thanks so much for the feedback! OK, so I changed how I copy the values, following your advice. I changed

min_bbox = t.cuda.FloatTensor([edge_angles[i], area, width, height, min_x, max_x, min_y, max_y ])

to

if (area < min_bbox[1]):
    min_bbox[0] = edge_angles[i]
    min_bbox[1] = area
    min_bbox[2] = width
    min_bbox[3] = height
    min_bbox[4] = min_x
    min_bbox[5] = max_x
    min_bbox[6] = min_y
    min_bbox[7] = max_y

And it looks like I am no longer losing my gradient. But do I even need this gradient?

lilanxiao commented 3 years ago

In GIoU loss, this term should be minimized:

(convex_area - union) / (convex_area + 1e-16)

If the convex_area is not differentiable, optimizing the GIoU loss would maximize the area of the union, but it doesn't have a direct impact on the area of the enclosing box.

If the convex_area is differentiable, optimizing the GIoU loss would maximize the area of the union AND minimize the area of the enclosing box.

So, I think the answer depends on what you expect. If you want to minimize the area of the enclosing box in your task (like common object detection tasks), you should make the area differentiable.

If not, I think the loss function also works without that gradient. In this case, the area is simply a non-differentiable factor, which makes the GIoU loss scale-invariant.
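The difference between the two cases can be checked on scalars. This is only a sketch: giou_penalty is a hypothetical helper, and the values 4.0 and 3.0 are made-up stand-ins for the enclosing-box area and the union.

```python
import torch

def giou_penalty(enclosing_area, union, differentiable_area=True):
    # Detaching turns the enclosing area into a plain constant factor.
    area = enclosing_area if differentiable_area else enclosing_area.detach()
    return (area - union) / (area + 1e-16)

enclosing = torch.tensor(4.0, requires_grad=True)
union = torch.tensor(3.0, requires_grad=True)

# Case 1: the area is differentiable, so both inputs receive a gradient.
giou_penalty(enclosing, union, True).backward()
grad_area_diff = enclosing.grad.item()    # U / C^2 = 3/16
grad_union_diff = union.grad.item()       # -1 / C  = -1/4

enclosing.grad = None
union.grad = None

# Case 2: the area is detached, so only the union receives a gradient.
giou_penalty(enclosing, union, False).backward()
grad_area_detached = enclosing.grad       # stays None
grad_union_detached = union.grad.item()   # still -1/4
```

The positive gradient on the area in case 1 means a descent step shrinks the enclosing box, while the negative gradient on the union means it grows in both cases.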

JonathanCMitchell commented 3 years ago

And you are using the first or second case?

lilanxiao commented 3 years ago

As I said, I simply use the brute force search to get the smallest enclosing box. The method has O(n^2) complexity but it's easy to implement in PyTorch. My implementation in min_enclosing_box.py is fully PyTorch-based and is thus differentiable. So, I'm using the first case. I've tested it in 3D object detection projects; my implementation is not very fast but works well.

I'm not sure why you need exactly the rotating caliper, perhaps you have more points so that the O(n^2) complexity of the brute force method is not acceptable. But notice that the rotating caliper doesn't work with raw points. To use the rotating caliper, you have to:

  1. find the convex hull of all points, which has the complexity of O(n*log(n)), and is really hard to implement on GPU.
  2. sort the vertices of the convex hull in clockwise order.

Then, the rotating caliper algorithm works with the sorted convex hull.
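For reference, the two preparation steps can be sketched on the CPU with Andrew's monotone chain algorithm, which produces the hull already in order. This is plain Python rather than a GPU or autograd-friendly version, and convex_hull here is an illustrative helper, not a function from this repository.

```python
def convex_hull(points):
    """Return the convex hull of 2D points in counter-clockwise order."""
    pts = sorted(set(map(tuple, points)))  # sort by x, then y; drop duplicates
    if len(pts) <= 2:
        return pts

    def cross(o, a, b):
        # z-component of (a - o) x (b - o); positive for a left turn
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    lower = []
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)

    upper = []
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)

    # The last point of each chain is the first point of the other chain.
    return lower[:-1] + upper[:-1]

hull_ccw = convex_hull([(0, 0), (2, 0), (2, 2), (0, 2), (1, 1), (1, 0)])
hull_cw = hull_ccw[::-1]  # reverse for clockwise order
```

The sort dominates the cost, which is where the O(n*log(n)) comes from. Interior and collinear points such as (1, 1) and (1, 0) are dropped.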

If you don't have a large number of points, I would recommend the brute force search. It provides the same results as the rotating caliper. The code in min_enclosing_box.py assumes that you have only 8 points. But the code can be easily modified to work with an arbitrary number of points.

JonathanCMitchell commented 3 years ago

So I have a torch implementation to find the convex hull, and a torch implementation to use those outputs to find the min bounding rect (above). But still found some issues and it didn't converge :(

I got your implementation to converge (video below) but there is a jitter in the predictions. This simple test case uses a very simple DNN with a single input and a single output. The plots show the Prediction (orange) and the GT (green). We can see that there is jitter in the predictions. https://user-images.githubusercontent.com/13068956/113040900-30f87d00-914e-11eb-9432-b5f2bd94f52f.mp4

So I also wanted to plot the minimum bounding box around it to see if that is changing, but I don't know how to get it out of the smallest_enclosing_box function.

Thanks again!

lilanxiao commented 3 years ago

Ah, I have read the code before. It mixes PyTorch operators with native Python control flow. Theoretically, it should work (although the mixed code is usually slow and doesn't work in batch). There might be some technical issue that makes the back-propagation go wrong. Also, you might get unexpected results if the vertices of the convex hull are not sorted.

I think the jitter is reasonable. It's the common behavior of Gradient Descent. The algorithm usually cannot exactly reach the minimum. Instead, the result moves back and forth around the minimum.
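A toy illustration of that back-and-forth behavior, assuming a fixed step size and the loss |x|, whose gradient does not shrink near the minimum. It is not the IoU loss from the video, just a stand-in with a similar kink at the optimum.

```python
def descend_abs(x0=1.0, lr=0.3, steps=10):
    """Fixed-step gradient descent on f(x) = |x|; returns every iterate."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        grad = 1.0 if x > 0 else -1.0  # (sub)gradient of |x|
        x = x - lr * grad
        xs.append(x)
    return xs

trajectory = descend_abs()
```

The iterates approach zero and then hop between roughly 0.1 and -0.2 indefinitely. A smaller or decaying learning rate narrows the hop but does not remove it for a loss of this shape.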

Yeah, the function doesn't return the center and rotation of the box, only its size. If you set the verbose argument to True, the function also returns the index of the edge with which the minimum box is collinear (see the test case in min_enclosing_box.py). This information might help with the validation. But more information is not available.