facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

What is meant by non-unique 6D representations? #1620

Closed: PottedRosePetal closed this issue 1 year ago

PottedRosePetal commented 1 year ago

The documentation for matrix_to_rotation_6d says "Note that 6D representation is not unique." I went over the paper a little, but it seems to be a bit over my head. What exactly is meant by that? I tried the following code and couldn't find anything indicating non-uniqueness:

import torch
from pytorch3d.transforms import (
    axis_angle_to_quaternion,
    euler_angles_to_matrix,
    matrix_to_euler_angles,
    matrix_to_rotation_6d,
    quaternion_to_axis_angle,
    rotation_6d_to_matrix,
)


def are_tensors_approx_equal(a, b, atol=1e-4):
    # Assumed definition: the original helper was not included in the post.
    # float32 round trips are not exact, so compare with a tolerance.
    return torch.allclose(a, b, atol=atol)


num_disc_angles = 100  # 100**3 = 1,000,000 samples; this Python loop is slow, reduce for a quick run

discrete_angles = torch.linspace(0, 2 * torch.pi, num_disc_angles)
angle_grid = torch.meshgrid(discrete_angles, discrete_angles, discrete_angles, indexing="xy")
angle_tensor = torch.stack(angle_grid, dim=-1).reshape(-1, 3)
total_angles = angle_tensor.shape[0]
correct_quat = 0
correct_6d = 0
for angles in angle_tensor:
    # Quaternion check: `angles` is interpreted here as an axis-angle vector.
    quat = axis_angle_to_quaternion(angles)
    angles_quat = quaternion_to_axis_angle(quat)
    if not are_tensors_approx_equal(angles, angles_quat):
        print("Quaternions:", angles, angles_quat)
    else:
        correct_quat += 1

    # 6D check: `angles` is interpreted here as XYZ Euler angles.
    matrix = euler_angles_to_matrix(angles, "XYZ")
    rot_6d = matrix_to_rotation_6d(matrix)
    # (model training would use rot_6d here)
    rot_mat_6d = rotation_6d_to_matrix(rot_6d)
    angles_6d = matrix_to_euler_angles(rot_mat_6d, "XYZ")
    if not are_tensors_approx_equal(matrix, rot_mat_6d):
        print("6D:", angles, torch.abs(angles_6d))
    else:
        correct_6d += 1
print()
print(correct_6d, correct_quat, total_angles)

For the quaternion part, it seems the cutoff is pi, beyond which it behaves differently. I think I read about that somewhere, but it would be nice to include it in the docs if that is intended behaviour. For the last element, my output is:

Quaternions: tensor([6.2832, 6.2832, 6.2832]) tensor([-0.9720, -0.9720, -0.9720])

which seems insane to me, to be honest. But I don't really know quaternions, so that's that. It's obvious that the angles are broken due to gimbal lock, but that shouldn't apply to rotations other than a yaw of pi/2; at least it didn't for the 6D representation.

I was just wondering if I need to prepare for some bad surprises once I use the 6D representation to train my model.

bottler commented 1 year ago

Using 3 numbers in intervals to represent a rotation (as Euler angles do) has this problem. Rotation space is compact: there is no "edge" in any direction, but the 3 numbers have "edges" (which somehow "wrap") in all directions. So it's not just extreme values of yaw: in every rotation direction there must be a discontinuity in the representation. You need to work out for yourself whether this matters for your work; we can't help with that.