open-mmlab / mmcv

OpenMMLab Computer Vision Foundation
https://mmcv.readthedocs.io/en/latest/
Apache License 2.0

[Bug] the box_iou_rotated convex_hull sorting seems to have a problem #2933

Open DanieeelLiu opened 1 year ago

DanieeelLiu commented 1 year ago

Prerequisite

Environment

pytorch 1.9.0 + mmcv

Reproduces the problem - code sample

I revised lines L189~L212 of https://github.com/open-mmlab/mmcv/blob/main/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp#L189~#L212 as follows, to print the points before and after the sort:

  for (int i = 0; i < num_in; i++) {
    dist[i] = dot_2d<T>(q[i], q[i]);
  }
  // Debug: print the points (relative to the lowest point) before sorting
  printf("before sorting\n");
  for (int i = 0; i < num_in; i++) {
    printf("x,y: %f, %f\n", q[i].x, q[i].y);
  }

#ifdef __CUDACC__
  // CUDA version
  // In the future, we can potentially use thrust
  // for sorting here to improve speed (though not guaranteed)
  for (int i = 1; i < num_in - 1; i++) {
    for (int j = i + 1; j < num_in; j++) {
      T crossProduct = cross_2d<T>(q[i], q[j]);
      if ((crossProduct < -1e-6) ||
          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
        auto q_tmp = q[i];
        q[i] = q[j];
        q[j] = q_tmp;
        auto dist_tmp = dist[i];
        dist[i] = dist[j];
        dist[j] = dist_tmp;
      }
    }
  }
  // Debug: print the points after sorting
  printf("sorting result\n");
  for (int i = 0; i < num_in; i++) {
    printf("x,y: %f, %f\n", q[i].x, q[i].y);
  }
  // (the #else CPU branch and the closing #endif are unchanged and omitted here)
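To experiment with the CUDA sort without recompiling, here is a minimal pure-Python sketch that mirrors the swap-sort loop above. It assumes the five printed "before sorting" points as input (rounded to six decimals by `%f`, so the cross products only approximate the GPU values). `cross_2d` and `dot_2d` follow the kernel's names; `sort_like_kernel` and its `eps` parameter are hypothetical, with `eps` standing in for the `1e-6` literal.

```python
# Points as printed "before sorting" (rounded to 6 decimals by printf("%f")).
# q[0] is the pivot (lowest point), which the kernel has already moved to the origin.
POINTS = [
    (0.0, 0.0),
    (0.000869, 0.000739),
    (0.002780, 0.000530),
    (0.002675, 0.001083),
    (-0.000094, 0.000155),
]


def cross_2d(a, b):
    """z-component of the 2-D cross product a x b (positive if b is CCW of a)."""
    return a[0] * b[1] - a[1] * b[0]


def dot_2d(a, b):
    return a[0] * b[0] + a[1] * b[1]


def sort_like_kernel(points, eps=1e-6):
    """Python mirror of the CUDA swap sort; eps plays the role of the 1e-6 literal."""
    q = list(points)
    dist = [dot_2d(p, p) for p in q]
    n = len(q)
    for i in range(1, n - 1):
        for j in range(i + 1, n):
            c = cross_2d(q[i], q[j])
            # Swap if q[j] is clockwise of q[i], or if the pair counts as
            # "collinear" (|c| < eps) and q[j] is closer to the origin.
            if c < -eps or (abs(c) < eps and dist[i] > dist[j]):
                q[i], q[j] = q[j], q[i]
                dist[i], dist[j] = dist[j], dist[i]
    return q


for x, y in sort_like_kernel(POINTS):
    print(f"x,y: {x:f}, {y:f}")
```

With the rounded values this reproduces the reported "sorting result" order, with (-0.000094, 0.000155) moved to position 1.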

Reproduces the problem - command or script

test_box_iou_rotated_issue.py


import numpy as np
import pytest
import torch

class TestBoxIoURotated:

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires CUDA support')
    def test_box_iou_rotated_iof_cuda(self):
        from mmcv.ops import box_iou_rotated
        np_boxes1 = np.asarray(
            [[-0.00028137, 0.00601164, 0.00056274, 0.00343669, -1.38236]],
            dtype=np.float32)
        np_boxes2 = np.asarray(
            [[0.00433421, 0.0040969, 0.00866843, 0.0081938, 2.11528]],
            dtype=np.float32)

        boxes1 = torch.from_numpy(np_boxes1).cuda()
        boxes2 = torch.from_numpy(np_boxes2).cuda()

        # test cw angle definition
        ious = box_iou_rotated(boxes1, boxes2, mode='iou', aligned=True)

command: pytest test_box_iou_rotated_issue.py (add -s to turn off pytest's output capturing so the printf output is shown)

Reproduces the problem - error message

See the printed output below.

Additional information

before sorting
x,y: 0.000000, 0.000000
x,y: 0.000869, 0.000739
x,y: 0.002780, 0.000530
x,y: 0.002675, 0.001083
x,y: -0.000094, 0.000155

sorting result
x,y: 0.000000, 0.000000
x,y: -0.000094, 0.000155
x,y: 0.002780, 0.000530
x,y: 0.002675, 0.001083
x,y: 0.000869, 0.000739

In my view, the point (-0.000094, 0.000155) should come last, since it has by far the largest polar angle around the origin. The sorting result is wrong.
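A quick way to check the expected order, independent of any threshold, is to sort the non-pivot points by polar angle with `math.atan2`. This is only a cross-check on the rounded printed values, not the kernel's method (the kernel compares cross products rather than computing angles). Because the pivot is the lowest point, every other point lies in the upper half-plane, so angle order and cross-product order should coincide.

```python
import math

# Non-origin points from the "before sorting" output (rounded to 6 decimals).
points = [
    (0.000869, 0.000739),
    (0.002780, 0.000530),
    (0.002675, 0.001083),
    (-0.000094, 0.000155),
]

# Counter-clockwise order around the pivot at the origin = increasing polar angle.
by_angle = sorted(points, key=lambda p: math.atan2(p[1], p[0]))

for x, y in by_angle:
    print(f"x,y: {x:f}, {y:f}  angle: {math.degrees(math.atan2(y, x)):6.1f} deg")
```

The last point is (-0.000094, 0.000155), at roughly 121 degrees, well away from the others (all below about 41 degrees).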

One more thing: if two boxes intersect, the overlapping region is a convex polygon (the intersection of two convex shapes is always convex), so I wonder whether step 5 here is still needed: https://github.com/open-mmlab/mmcv/blob/main/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp#L245~#L257

grimoire commented 1 year ago

This function computes the convex hull with a Graham scan. The sort is performed on the distance to the origin point. Since (-0.000094, 0.000155) is the point nearest the origin, I think this result is expected. The sort is not the final step of the Graham scan: lines 245~257 are the scan step (as the name indicates), which is part of the algorithm together with the sort and the other steps.
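For readers following along, the scan step can be sketched in Python as below. It assumes the kernel's Step 5 pop rule as I understand it (pop q[m-1] while `cross_2d(q[i] - q[m-2], q[m-1] - q[m-2]) >= 0`, i.e. while the last three points do not make a strict counter-clockwise turn), and it skips Step 4's handling of points that coincide with the pivot. `graham_scan_step` is a hypothetical name. The two inputs are the order printed by the kernel above and the order obtained by sorting on polar angle.

```python
def cross_2d(a, b):
    return a[0] * b[1] - a[1] * b[0]


def sub(a, b):
    return (a[0] - b[0], a[1] - b[1])


def graham_scan_step(q):
    """Scan an angle-sorted point list (q[0] = pivot), popping non-left turns.

    Assumes q[1] does not coincide with q[0].
    """
    stack = [q[0], q[1]]
    for p in q[2:]:
        # Pop the top while (second-to-top -> top -> p) is not a strict CCW turn.
        while len(stack) > 1 and cross_2d(sub(p, stack[-2]), sub(stack[-1], stack[-2])) >= 0:
            stack.pop()
        stack.append(p)
    return stack


p0, p1, p2, p3, p4 = (
    (0.0, 0.0),
    (0.000869, 0.000739),
    (0.002780, 0.000530),
    (0.002675, 0.001083),
    (-0.000094, 0.000155),
)

kernel_order = [p0, p4, p2, p3, p1]  # order printed by the kernel (1e-6 threshold)
angle_order = [p0, p2, p3, p1, p4]   # order by increasing polar angle

print(graham_scan_step(kernel_order))
print(graham_scan_step(angle_order))
```

The first call drops (-0.000094, 0.000155) and returns 4 points; the second keeps all 5.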

DanieeelLiu commented 1 year ago

The scan step (lines 245~257) deletes the point (-0.000094, 0.000155) according to its rule, so only 4 points are left after scanning:
x,y: 0.000000, 0.000000
x,y: 0.002780, 0.000530
x,y: 0.002675, 0.001083
x,y: 0.000869, 0.000739

If the sorting result were correct, all five points would be included in the convex hull. Compare the area with a correct sorting result (a 5-point hull, 0.217719) against the wrong result (a 4-point hull, 0.018336): strictly speaking, this leads to a small difference in the IoU. I'm wondering whether such a small difference is acceptable in training, so that there is no need to handle these extreme cases. What is your opinion?
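The size of the effect can be estimated with the shoelace formula on the two hulls, plugged into IoU = I / (A1 + A2 - I) with the box sizes from the test script (assuming the (x, y, w, h, angle) layout mmcv uses for rotated boxes). This is a rough sketch on the six-decimal printed coordinates, so it will not match the kernel's float32 numbers digit for digit; `polygon_area` is a hypothetical helper.

```python
def polygon_area(poly):
    """Shoelace formula for a simple polygon given as a list of (x, y) vertices."""
    total = 0.0
    for i in range(len(poly)):
        x0, y0 = poly[i]
        x1, y1 = poly[(i + 1) % len(poly)]
        total += x0 * y1 - x1 * y0
    return abs(total) / 2.0


# Hull after the scan with the wrong sort (4 points) and with a correct sort (5 points).
hull_4 = [(0.0, 0.0), (0.002780, 0.000530), (0.002675, 0.001083), (0.000869, 0.000739)]
hull_5 = hull_4 + [(-0.000094, 0.000155)]

# Box areas w * h, taken from np_boxes1 / np_boxes2 in the test script.
area1 = 0.00056274 * 0.00343669
area2 = 0.00866843 * 0.0081938

for name, hull in (("4-point", hull_4), ("5-point", hull_5)):
    inter = polygon_area(hull)
    iou = inter / (area1 + area2 - inter)
    print(f"{name} hull: intersection area {inter:.4e}, IoU {iou:.4f}")
```

With these rounded inputs the IoU comes out near 0.0183 for the 4-point hull and near 0.0198 for the 5-point hull; the latter agrees with the result reported in the next reply after lowering the threshold.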

grimoire commented 1 year ago

OK, I see. The kernel uses a relatively large threshold (1e-6) when sorting (https://github.com/open-mmlab/mmcv/blob/main/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp#L201), which ignores small angle differences: point pairs whose cross product is below 1e-6 in magnitude are treated as collinear and ordered by distance instead. That leads to the error you found. After changing the threshold to 1e-8, the sorted points are:

x,y: 0.000000, 0.000000
x,y: 0.002780, 0.000530
x,y: 0.002675, 0.001083
x,y: 0.000869, 0.000739
x,y: -0.000094, 0.000155

The resulting IoU is 0.0198.
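Why does 1e-8 work where 1e-6 fails? For 2-D vectors |a × b| = |a||b| sin θ, so with coordinates around 1e-3 every cross product here is at most a few times 1e-6, however large the angle between the points. The sketch below lists all pairwise cross products of the four non-pivot points (labelled A to D purely for illustration, D being (-0.000094, 0.000155)) and checks them against both thresholds.

```python
from itertools import combinations

# Non-pivot points from the printed output; labels are illustrative only.
points = {
    "A": (0.000869, 0.000739),
    "B": (0.002780, 0.000530),
    "C": (0.002675, 0.001083),
    "D": (-0.000094, 0.000155),
}


def cross_2d(a, b):
    return a[0] * b[1] - a[1] * b[0]


crosses = {}
for (na, a), (nb, b) in combinations(points.items(), 2):
    c = cross_2d(a, b)
    crosses[na + nb] = c
    # Pairs with |c| < eps fall back to the distance tie-break in the kernel sort.
    print(f"{na}x{nb} = {c:+.3e}  |c|<1e-6: {abs(c) < 1e-6}  |c|<1e-8: {abs(c) < 1e-8}")
```

Every pair involving D lands between 1e-8 and 1e-6, so with the 1e-6 threshold the kernel treats D as collinear with each other point and compares distances instead; D is the closest to the origin, so it is moved to the front. With 1e-8 no pair falls under the threshold and the order is decided by the cross products alone.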

Since I have no background in rotated detection, I cannot say whether such a difference is acceptable.

DanieeelLiu commented 1 year ago

Yes, the points with small angles were ignored because of the 1e-6 threshold (your IoU result is correct). So will you fix this with a patch? Furthermore, this case is computed correctly with a smaller threshold (1e-8), but what if there are points even closer to the origin?

grimoire commented 1 year ago

Sure, I will fix it soon. The range of the cross product depends on the lengths of the vectors (|a × b| = |a||b| sin θ), so I guess we can normalize the vectors before calling cross_2d.
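A minimal sketch of that idea, for illustration only (not necessarily what the eventual patch does): divide the cross product by |a||b| so the threshold applies to sin θ, which does not depend on vector length. `normalized_cross` and `sort_by_angle` are hypothetical names, and eps = 1e-6 is an arbitrary choice here. The check scales every point by a further 1e-3 and confirms the order is unchanged.

```python
import math


def cross_2d(a, b):
    return a[0] * b[1] - a[1] * b[0]


def normalized_cross(a, b):
    """sin of the angle from a to b; 0 if either vector has zero length."""
    na, nb = math.hypot(*a), math.hypot(*b)
    if na == 0.0 or nb == 0.0:
        return 0.0
    return cross_2d(a, b) / (na * nb)


def sort_by_angle(points, eps=1e-6):
    """Same swap sort as the kernel, but the threshold is applied to sin(theta)."""
    q = list(points)
    dist = [p[0] ** 2 + p[1] ** 2 for p in q]
    for i in range(1, len(q) - 1):
        for j in range(i + 1, len(q)):
            c = normalized_cross(q[i], q[j])
            if c < -eps or (abs(c) < eps and dist[i] > dist[j]):
                q[i], q[j] = q[j], q[i]
                dist[i], dist[j] = dist[j], dist[i]
    return q


points = [(0.0, 0.0), (0.000869, 0.000739), (0.002780, 0.000530),
          (0.002675, 0.001083), (-0.000094, 0.000155)]
tiny = [(x * 1e-3, y * 1e-3) for x, y in points]  # same shape, 1000x smaller

order = sort_by_angle(points)
order_tiny = sort_by_angle(tiny)
print([points.index(p) for p in order], [tiny.index(p) for p in order_tiny])
```

A zero-length vector (a point coinciding with the pivot) yields 0, so such points fall back to the distance tie-break and end up at the front. One trade-off: for points extremely close to the pivot, the direction itself is dominated by rounding noise, which normalization cannot fix.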

DanieeelLiu commented 1 year ago

OK, thank you very much for your reply! I will wait for the patch. (P.S. Will there be a patch for 1.6.1, or only for master?) (P.S. 2: One more thing: if two boxes intersect, the overlapping polygon should be convex, so I wonder whether the scan step here is still needed? https://github.com/open-mmlab/mmcv/blob/main/mmcv/ops/csrc/common/box_iou_rotated_utils.hpp#L245~#L257)

grimoire commented 1 year ago

"Will there be a patch for 1.6.1, or only for master?"

Master only.

"One more thing: if two boxes intersect, the overlapping polygon should be convex, so I wonder whether the scan step here is still needed?"

Theoretically, yes: the overlap is convex, so the scan step should not be strictly necessary. But I will keep it to avoid missing corner cases.
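Even if the exact intersection is convex, the list of candidate vertices may not be a clean convex polygon: for instance, a corner of one box lying exactly on an edge of the other may plausibly be found twice, and some points can be collinear. The sketch below uses a hypothetical `scan` function with the same pop rule as the kernel's scan step (as I understand it) to show that such points are dropped, which is the kind of corner case keeping the step protects against.

```python
def cross_2d(a, b):
    return a[0] * b[1] - a[1] * b[0]


def scan(q):
    """Graham scan over points already sorted by angle around q[0]."""
    stack = [q[0], q[1]]
    for p in q[2:]:
        # Pop the top while (second-to-top -> top -> p) is not a strict CCW turn;
        # a zero cross product (duplicate or collinear point) also pops.
        while len(stack) > 1:
            ref = stack[-2]
            v_new = (p[0] - ref[0], p[1] - ref[1])
            v_top = (stack[-1][0] - ref[0], stack[-1][1] - ref[1])
            if cross_2d(v_new, v_top) >= 0:
                stack.pop()
            else:
                break
        stack.append(p)
    return stack


# A unit square, already sorted around (0, 0), with a duplicated corner (1, 0)
# and an extra point (0.5, 1) lying on the top edge.
sorted_pts = [(0, 0), (1, 0), (1, 0), (1, 1), (0.5, 1), (0, 1)]
print(scan(sorted_pts))
```

Only the four true corners survive: the duplicate (1, 0) and the collinear (0.5, 1) are both popped.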