facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Potential IoU computation bug in box3d_overlap #1771

Open YLouWashU opened 7 months ago

YLouWashU commented 7 months ago

box3d_overlap returns an IoU value larger than 1

box3d_overlap sometimes computes an incorrect IoU, which can even be larger than 1. Note that this bug is different from the one addressed by the fix mentioned here on GitHub, and it actually occurs with a low eps setting.

Instructions To Reproduce the Issue:

Minimal reproducing code (data: bbox_ab_data.pth.zip)

import torch
from pytorch3d.ops import box3d_overlap

# Unzip the attached file and fill in its path below:
# https://github.com/facebookresearch/pytorch3d/files/14826992/bbox_ab_data.pth.zip

bbox_ab_data_file = "FILEPATH"  # placeholder: replace with the path to the unzipped .pth file
bbox_ab_data = torch.load(bbox_ab_data_file)

box_a = bbox_ab_data["box_a"].unsqueeze(0)
box_b = bbox_ab_data["box_b"].unsqueeze(0)
print(box_a)
print(box_b)

# compute IOU
vol_in, iou = box3d_overlap(box_a, box_b)
print(vol_in, iou)
"""
observed output (note that the IoU is larger than 1):

tensor([[1.0696]]) tensor([[1.1747]])
"""

We validated this with several independent methods (random sampling, a convex hull implementation, and the Objectron implementation), and PyTorch3D's IoU does indeed come out larger than 1.
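The random-sampling check mentioned above can be sketched roughly as follows. This is not the reporters' actual validation code: `points_in_box` and `sampled_iou` are hypothetical helpers written for illustration, and they assume each box is a true cuboid (orthogonal edges) given as 8 corners in the PyTorch3D vertex order. Because the estimate is a ratio of point counts, it can never exceed 1, which makes it a useful independent reference.

```python
import numpy as np


def points_in_box(points: np.ndarray, corners: np.ndarray) -> np.ndarray:
    """Boolean mask of the points that lie inside a cuboid.

    Assumes (8, 3) corners in the PyTorch3D order, where corners 1, 3 and 4
    are the three neighbours of corner 0, and assumes orthogonal edges.
    """
    origin = corners[0]
    edges = corners[[1, 3, 4]] - origin            # (3, 3), one edge per row
    proj = (points - origin) @ edges.T             # projections onto each edge
    lengths_sq = np.sum(edges * edges, axis=1)     # squared edge lengths
    return np.all((proj >= 0.0) & (proj <= lengths_sq), axis=1)


def sampled_iou(box_a, box_b, num_samples: int = 200_000, seed: int = 0) -> float:
    """Monte Carlo IoU estimate from uniform samples in the joint bounding box."""
    box_a = np.asarray(box_a, dtype=float)
    box_b = np.asarray(box_b, dtype=float)
    rng = np.random.default_rng(seed)
    both = np.concatenate([box_a, box_b], axis=0)
    points = rng.uniform(both.min(axis=0), both.max(axis=0), size=(num_samples, 3))
    in_a = points_in_box(points, box_a)
    in_b = points_in_box(points, box_b)
    union = np.count_nonzero(in_a | in_b)
    inter = np.count_nonzero(in_a & in_b)
    return inter / union if union > 0 else 0.0


# Two unit cubes offset by half an edge: the analytic IoU is 0.5 / 1.5 = 1/3.
unit = np.array(
    [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
     [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]],
    dtype=float,
)
shifted = unit + np.array([0.5, 0.0, 0.0])
estimate = sampled_iou(unit, shifted)
print(estimate)
```

The estimate is noisy, so it is only suitable for spotting gross errors such as an IoU above 1, not for checking the last few digits.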

gkioxari commented 7 months ago

I am looking into this.

When I run your code with the boxes you provided, I get ValueError: Plane vertices are not coplanar (I am using the default eps=1e-4 threshold). So your data are likely not perfect cuboids. Where did you get the boxes from?
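To see how far a given box is from passing a coplanarity test like the one behind this error, a check along the following lines can help. It is a minimal sketch and not PyTorch3D's internal implementation: the `FACES` table and `max_coplanarity_error` are written for this example, assuming the 8-corner vertex order that PyTorch3D documents.

```python
import numpy as np

# Corner indices of the six faces, derived from the unit-box vertex order
# (corners 0-3 form the bottom face, 4-7 the top face). Listing written for
# this sketch, not copied from the library.
FACES = np.array([
    [0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4],
    [1, 2, 6, 5], [2, 3, 7, 6], [3, 0, 4, 7],
])


def max_coplanarity_error(corners) -> float:
    """Largest distance of any face's 4th vertex from the plane of its first 3."""
    corners = np.asarray(corners, dtype=float)
    v = corners[FACES]                                  # (6, 4, 3)
    v0, v1, v2, v3 = v[:, 0], v[:, 1], v[:, 2], v[:, 3]
    normal = np.cross(v1 - v0, v2 - v0)                 # one normal per face
    normal = normal / np.linalg.norm(normal, axis=1, keepdims=True)
    return float(np.max(np.abs(np.sum((v3 - v0) * normal, axis=1))))


unit = np.array(
    [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
     [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]],
    dtype=float,
)
bent = unit.copy()
bent[2, 2] += 0.01          # lift one corner out of the bottom plane

print(max_coplanarity_error(unit))   # exactly planar faces
print(max_coplanarity_error(bent))   # far above a 1e-4 threshold
```

Comparing the returned value against the eps in use shows whether a box fails because it is genuinely distorted or only because of rounding in the stored corners.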

zhengkang86 commented 7 months ago

Hi @gkioxari, thanks for looking into this issue! We used the following code to generate pairs of 3D boxes with small random perturbations.

import math
from typing import Callable, Tuple

import numpy as np
import torch
from pytorch3d.ops import box3d_overlap
from pytorch3d.transforms import euler_angles_to_matrix
from torch import tensor

def get_random_dim_and_T(
    num_sample: int = 100,
    max_scale_factor: float = 1.05,
    max_translation_factor: float = 0.05,
    max_rotation_radian: float = math.pi / 20,
    rand_fn: Callable = torch.rand,
    random_seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get some random dimensions and transformation matrices

    Args:
        num_sample: number of samples to generate.
        max_scale_factor: maximum scale factor to perturb object box dimensions.
        max_translation_factor: maximum translation factor to perturb object box center.
        max_rotation_radian: maximum rotation radian to perturb object box rotation.
        rand_fn: random function to use.
        random_seed: random seed to use.

    Returns:
        perturbed_dimensions: (num_sample, 3) tensor of perturbed object dimensions
        perturbed_T_ref_obj: (num_sample, 3, 4) tensor of perturbed transformation matrices from
            reference to object coordinate
    """
    dimension = tensor([1, 1, 1])
    translation = tensor([0, 0, 0])
    rotation_angle = tensor([0, 0, 0])
    R_ref_obj = euler_angles_to_matrix(rotation_angle, convention="XYZ")
    T_ref_obj = torch.cat((R_ref_obj, translation.unsqueeze(1)), dim=1)

    # Set the random seed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    min_scale_factor = 1 / max_scale_factor
    assert max_scale_factor >= min_scale_factor

    # perturb dimensions
    scale_factors = []
    for _ in [0, 1, 2]:
        scale_factors.append(
            rand_fn(num_sample, 1) * (max_scale_factor - min_scale_factor)
            + min_scale_factor
        )
    scale_factors = torch.cat(scale_factors, dim=1)
    perturbed_dimensions = dimension.repeat(num_sample, 1) * scale_factors

    # perturb translation
    translation_offsets = []
    for _ in [0, 1, 2]:
        translation_offsets.append(
            (rand_fn(num_sample, 1) * 2 - 1) * max_translation_factor
        )
    translation_offsets = torch.cat(translation_offsets, dim=1)
    perturbed_translations = T_ref_obj[:, 3].repeat(
        num_sample, 1
    ) + translation_offsets * dimension.repeat(num_sample, 1)

    # perturb rotation
    euler_angles_offsets = []
    for _ in [0, 1, 2]:
        euler_angles_offsets.append(
            (rand_fn(num_sample, 1) * 2 - 1) * max_rotation_radian
        )
    euler_angles_offsets = torch.cat(euler_angles_offsets, dim=1)
    rotation_offsets = euler_angles_to_matrix(euler_angles_offsets, convention="XYZ")
    perturbed_rotations = T_ref_obj[:, :3].repeat(num_sample, 1, 1) @ rotation_offsets

    perturbed_T_ref_obj = torch.cat(
        [perturbed_rotations, perturbed_translations.unsqueeze(-1)], dim=-1
    )

    return perturbed_dimensions, perturbed_T_ref_obj

def get_cuboid_corners(half_extents: torch.Tensor) -> torch.Tensor:
    """
    Build cuboid corners in the vertex order PyTorch3D expects; see
    https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/ops/iou_box3d.py#L111

    (4) +---------+. (5)
        | ` .     |  ` .
        | (0) +---+-----+ (1)
        |     |   |     |
    (7) +-----+---+. (6)|
        ` .   |     ` . |
        (3) ` +---------+ (2)
    NOTE: Throughout this implementation, we assume that boxes
    are defined by their 8 corners exactly in the order specified in the
    diagram above for the function to give correct results. In addition
    the vertices on each plane must be coplanar.
    As an alternative to the diagram, this is a unit bounding
    box which has the correct vertex ordering:

    box_corner_vertices = [
        [0, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 1],
        [0, 1, 1],
    ]

    Args:
        half_extents: Half extents of the cuboid in the format of (N, 3)

    Returns:
        corners: Corners of the cuboid in the format of (N, 8, 3)
    """
    corners = torch.tensor(
        [
            [-1, -1, -1],
            [1, -1, -1],
            [1, 1, -1],
            [-1, 1, -1],
            [-1, -1, 1],
            [1, -1, 1],
            [1, 1, 1],
            [-1, 1, 1],
        ],
        dtype=torch.float32,
        device=half_extents.device,
    )

    # Scale the corners of the unit box
    corners = corners * half_extents.unsqueeze(1)

    return corners

def batch_transform_points(
    points_in_B: torch.Tensor, T_A_B: torch.Tensor
) -> torch.Tensor:
    """
    Return points_in_A = R_A_B @ points_in_B + t_A_B in shape [N x M x 3]
    Args:
        points_in_B (torch.Tensor): NxMx3 tensor for a batch N of M 3D points in frame B,
          corresponding to each transformation matrix
        T_A_B (torch.Tensor): Nx3x4 transformation matrices from B to A

    Returns:
        points_in_A (torch.Tensor): NxMx3 tensor of a batch N of M 3D points in frame A, after
            transformation
    """

    # Reshape points to (N, 3, M) for batch matrix multiplication
    points_in_B_reshaped = points_in_B.permute(0, 2, 1)
    M = points_in_B_reshaped.shape[-1]

    R_A_B = T_A_B[:, :, :3]
    t_A_B = T_A_B[:, :, 3]
    points_in_A = (
        torch.bmm(R_A_B, points_in_B_reshaped) + t_A_B.unsqueeze(-1).repeat(1, 1, M)
    ).permute(0, 2, 1)

    return points_in_A

# generate random 3D boxes with small perturbation
random_dim, random_T_ref_obj = get_random_dim_and_T(
    num_sample=1000,
    max_scale_factor=1.05,
    max_translation_factor=0.05,
    max_rotation_radian=0.05,
)
num_a, num_b = 10, 10

dim_a = random_dim[:num_a]
t_world_a = random_T_ref_obj[:num_a, :, 3]
R_world_a = random_T_ref_obj[:num_a, :, :3]

dim_b = random_dim[num_a : num_a + num_b]
t_world_b = random_T_ref_obj[num_a : num_a + num_b, :, 3]
R_world_b = random_T_ref_obj[num_a : num_a + num_b, :, :3]

T_world_a = torch.cat((R_world_a, t_world_a.unsqueeze(-1)), dim=-1)
T_world_b = torch.cat((R_world_b, t_world_b.unsqueeze(-1)), dim=-1)

boxes_a = batch_transform_points(get_cuboid_corners(dim_a / 2), T_world_a)
boxes_b = batch_transform_points(get_cuboid_corners(dim_b / 2), T_world_b)

# compute iou
box_a = boxes_a[5].unsqueeze(0)
box_b = boxes_b[4].unsqueeze(0)
print(box_a)
print(box_b)
vol_in, iou = box3d_overlap(box_a, box_b)
print(vol_in, iou)

Below is the output we got:

tensor([[[-0.5315, -0.4913, -0.4733],
         [ 0.4789, -0.5244, -0.4560],
         [ 0.5116,  0.4529, -0.4952],
         [-0.4987,  0.4860, -0.5126],
         [-0.5468, -0.4520,  0.4926],
         [ 0.4636, -0.4851,  0.5100],
         [ 0.4963,  0.4922,  0.4707],
         [-0.5141,  0.5253,  0.4533]]])
tensor([[[-0.4757, -0.4932, -0.5059],
         [ 0.5480, -0.4848, -0.4593],
         [ 0.5414,  0.5224, -0.4987],
         [-0.4822,  0.5140, -0.5453],
         [-0.5210, -0.4548,  0.4832],
         [ 0.5026, -0.4464,  0.5298],
         [ 0.4961,  0.5608,  0.4904],
         [-0.5275,  0.5524,  0.4438]]])
tensor([[1.0696]]) tensor([[1.1747]])
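A quick way to confirm that this output cannot be right, independent of any IoU code, is to compare the reported intersection volume with the volumes of the two boxes: an intersection can never be larger than the smaller box. The sketch below uses the corner values printed above (rounded to four decimals); `cuboid_volume` is a hypothetical helper for this example that assumes the PyTorch3D vertex order.

```python
import numpy as np


def cuboid_volume(corners) -> float:
    """Volume from the scalar triple product of the three edges at corner 0."""
    corners = np.asarray(corners, dtype=float)
    e1 = corners[1] - corners[0]
    e2 = corners[3] - corners[0]
    e3 = corners[4] - corners[0]
    return float(abs(np.dot(e1, np.cross(e2, e3))))


box_a = [
    [-0.5315, -0.4913, -0.4733], [0.4789, -0.5244, -0.4560],
    [0.5116, 0.4529, -0.4952], [-0.4987, 0.4860, -0.5126],
    [-0.5468, -0.4520, 0.4926], [0.4636, -0.4851, 0.5100],
    [0.4963, 0.4922, 0.4707], [-0.5141, 0.5253, 0.4533],
]
box_b = [
    [-0.4757, -0.4932, -0.5059], [0.5480, -0.4848, -0.4593],
    [0.5414, 0.5224, -0.4987], [-0.4822, 0.5140, -0.5453],
    [-0.5210, -0.4548, 0.4832], [0.5026, -0.4464, 0.5298],
    [0.4961, 0.5608, 0.4904], [-0.5275, 0.5524, 0.4438],
]

vol_a = cuboid_volume(box_a)
vol_b = cuboid_volume(box_b)
reported_vol_in = 1.0696   # intersection volume printed by box3d_overlap above

print(vol_a, vol_b)
# True here: the reported intersection exceeds the smaller box's volume.
print(reported_vol_in > min(vol_a, vol_b))
```

Both box volumes come out below the reported intersection volume, so the error is already present in the intersection, before the IoU ratio is formed.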

gkioxari commented 7 months ago

I looked into this and submitted a PR last night. Because you generate boxes with small perturbations, you occasionally run into numerical issues caused by small values in the code's geometric computations. I think @bottler is taking care of merging the PR, so let me know if this fix works for you.

zhengkang86 commented 7 months ago

Hi @gkioxari, thanks for the fix PR! It does work for the case we reported earlier. But I later tried a few more cases and still got IoU values over 1. So I changed the epsilon threshold to 1e-6 and no longer saw any IoU > 1. Do you think there are any potential issues with just making it 1e-6?

YLouWashU commented 7 months ago

Thanks @gkioxari and @bottler for the timely fix! Can you share with us more details on what exactly is triggering this IOU>1?

@zhengkang86 In your test, how many random cases did you try without getting IoU > 1?

zhengkang86 commented 7 months ago

I tried 100 boxes x 100 boxes, so 10000 pairs. I think that should be enough.
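A batch test like this can be summarized with a small helper that counts out-of-range values in the pairwise IoU matrix. The sketch below is an assumption about how such a check might look, not the code used in this thread: `summarize_iou` is a hypothetical name, and it takes a plain NumPy array so it runs without PyTorch3D (the (N, M) IoU tensor returned by `box3d_overlap` could be passed in as `iou.cpu().numpy()`).

```python
import numpy as np


def summarize_iou(iou, tol: float = 1e-6) -> dict:
    """Count IoU entries that are non-finite or fall outside [0, 1] (within tol)."""
    iou = np.asarray(iou, dtype=float)
    invalid = ~np.isfinite(iou) | (iou < -tol) | (iou > 1.0 + tol)
    return {
        "num_pairs": int(iou.size),
        "num_invalid": int(np.count_nonzero(invalid)),
        "invalid_indices": np.argwhere(invalid).tolist(),
    }


# Toy 2 x 2 matrix containing one of the out-of-range values from this thread.
example = np.array([[0.5000, 1.1747],
                    [0.0000, 0.9944]])
report = summarize_iou(example)
print(report)
```

Note that this only catches the out-of-range symptom. A value inside [0, 1] can still be wrong, so a clean report is weaker evidence than a comparison against an independent IoU estimate.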

luchsonice commented 7 months ago

I assume this is due to the same numerical issues. It is also possible to get incorrect IoUs that are lower than 1: for example, the boxes below (visually indistinguishable from one another, see image) give an IoU of 0.8128.

import torch
from pytorch3d.ops import box3d_overlap

corners1 = torch.tensor([[
            [ 0.2411, -0.1752,  1.2247],
            [ 0.1951, -0.4194,  1.7741],
            [ 0.2036,  0.4826,  2.1757],
            [ 0.2495,  0.7267,  1.6263],
            [-0.2920, -0.1549,  1.1903],
            [-0.3380, -0.3991,  1.7396],
            [-0.3295,  0.5029,  2.1412],
            [-0.2835,  0.7471,  1.5919]]])

corners2 = torch.tensor([[
            [ 0.2390, -0.1764,  1.2246],
            [ 0.1930, -0.4205,  1.7740],
            [ 0.2055,  0.4813,  2.1759],
            [ 0.2515,  0.7254,  1.6265],
            [-0.2940, -0.1536,  1.1901],
            [-0.3400, -0.3978,  1.7395],
            [-0.3274,  0.5040,  2.1414],
            [-0.2815,  0.7482,  1.5920]]])

vol, iou = box3d_overlap(corners1, corners2)

print(iou[0][0])

[Image: vis_result]
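How close the two boxes are can be quantified directly from the corner coordinates above. The following is a small illustrative check, not part of the original report: it measures the largest distance between corresponding corners and compares it with the shortest box edge.

```python
import numpy as np

corners1 = np.array([
    [0.2411, -0.1752, 1.2247], [0.1951, -0.4194, 1.7741],
    [0.2036, 0.4826, 2.1757], [0.2495, 0.7267, 1.6263],
    [-0.2920, -0.1549, 1.1903], [-0.3380, -0.3991, 1.7396],
    [-0.3295, 0.5029, 2.1412], [-0.2835, 0.7471, 1.5919],
])
corners2 = np.array([
    [0.2390, -0.1764, 1.2246], [0.1930, -0.4205, 1.7740],
    [0.2055, 0.4813, 2.1759], [0.2515, 0.7254, 1.6265],
    [-0.2940, -0.1536, 1.1901], [-0.3400, -0.3978, 1.7395],
    [-0.3274, 0.5040, 2.1414], [-0.2815, 0.7482, 1.5920],
])

# Distance between each pair of corresponding corners.
displacement = np.linalg.norm(corners1 - corners2, axis=1)

# Lengths of the three edges meeting at corner 0 (PyTorch3D vertex order).
edge_lengths = np.linalg.norm(corners1[[1, 3, 4]] - corners1[0], axis=1)

max_shift = float(displacement.max())
relative_shift = max_shift / float(edge_lengths.min())
print(max_shift, relative_shift)
```

The corners move by less than one percent of the shortest edge, so an IoU close to 1 would be expected for this pair rather than 0.8128.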

zhengkang86 commented 7 months ago

I ran the IoU calculation on @luchsonice's boxes with the fix applied. It returns 0.9944.

YLouWashU commented 7 months ago

Hey @gkioxari, when can the fix in https://github.com/facebookresearch/pytorch3d/pull/1772 ship in an official release? Also, did the PR incorporate the even lower EPS value suggested by @zhengkang86? Thanks!