facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

transforms.matrix_to_quaternion only has gradients on main diagonal #503

Closed tomforge closed 3 years ago

tomforge commented 3 years ago

The trace-and-copysign implementation of matrix_to_quaternion means that gradients flow only to the main diagonal of the input matrix, which is the part used to compute the trace.

From what I understand of the axis-angle representation of rotation, a loss that can only change the trace implies that the axis of rotation never changes, only the angle.

Indeed, we can see this is the case in this toy example:

import torch
from pytorch3d import transforms
torch.manual_seed(0)

quats = transforms.random_quaternions(2)
init, target = quats[0:1], quats[1:2]
init_mat = transforms.quaternion_to_matrix(init)

mat = init_mat.clone().detach().requires_grad_(True)  # optimize the raw 3x3 matrix
optim = torch.optim.Adam([mat], lr=0.05)
for i in range(100):
    optim.zero_grad()
    q = transforms.matrix_to_quaternion(mat)  # the conversion this issue is about
    loss = torch.sum((q - target)**2)
    loss.backward()
    optim.step()
    if i % 10 == 0:
        aa = transforms.so3_log_map(mat).detach()
        angle = torch.norm(aa)
        axis = aa / angle
        print(axis, angle)

The output axis never changes. Hence, the init also never converges to the target.

I've tested two other matrix-to-quaternion implementations, one using axis angles as an intermediate representation and the other using the standard trace-based square-root formula. Both flow gradients to all elements of the input matrix and allow the example above to converge. Of course, there might be other issues with these implementations, which I did not test.

def mat2quat2(mat):
    # Convert via the axis-angle (log map) representation.
    aa = transforms.so3_log_map(mat)
    return transforms.axis_angle_to_quaternion(aa)

def mat2quat3(matrix):
    # Standard trace-based formula; assumes trace > -1 (rotation angle < pi).
    t = matrix[..., 0, 0] + matrix[..., 1, 1] + matrix[..., 2, 2]
    r = torch.sqrt(1 + t)
    s = 0.5 / r
    w = 0.5 * r
    x = (matrix[..., 2, 1] - matrix[..., 1, 2]) * s
    y = (matrix[..., 0, 2] - matrix[..., 2, 0]) * s
    z = (matrix[..., 1, 0] - matrix[..., 0, 1]) * s
    return torch.stack((w, x, y, z), -1)
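
To make the gradient-flow difference concrete, here is a quick check (a sketch only: it assumes the mat2quat3 definition above and the trace/copysign version of transforms.matrix_to_quaternion that this issue is about):

import torch
from pytorch3d import transforms

torch.manual_seed(0)
R = transforms.quaternion_to_matrix(transforms.random_quaternions(1))

R1 = R.clone().requires_grad_(True)
transforms.matrix_to_quaternion(R1).sum().backward()
print(R1.grad)  # with the trace/copysign version: nonzero only on the main diagonal

R2 = R.clone().requires_grad_(True)
mat2quat3(R2).sum().backward()
print(R2.grad)  # gradients reach all nine entries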

For context, I discovered this problem when trying to optimize Euler angle parameters (initialized to zero). They had to be converted to quaternions to fit into the training pipeline (I used euler_angles_to_matrix followed by matrix_to_quaternion), and I was getting no gradients on the parameters.

I have since unblocked myself, but I thought this might be something you'd want to look into.

bottler commented 3 years ago

There's a complication with the backprop function of any function which assumes its input lies on a (differentiable) manifold. There are typically multiple ways to implement such a function. For example, let's consider a function f which takes a tensor of shape (2,) where the second element is the cube of the first, and returns the fifth power of the first. (The requirement that the second element is the cube of the first means that the input is in a certain subset of FloatTensors of shape (2,) which in this case is "smooth" and "curvy", which I am calling the manifold.)

Here are two possible implementations, but there are infinitely many others.

def f1(x):
    return x[0]**5
def f2(x):
    return torch.square(x[0]) * x[1]

These functions will have different outputs from their backwards methods: f1 has no gradient in the direction of x[1]. But if you build a computation graph from inputs which are not restricted, whose output lands on the manifold, and then pass that value through f, then it does not matter which implementation you used. In fact, considerations of gradient shouldn't lead you to prefer f1 or f2 in this case. For more complicated functions there can be reasons why an implementation breaks the gradient, and there have been some issues like that in PyTorch3D, but neither of these implementations is broken.

t = torch.rand(1, requires_grad=True)
x = torch.cat([t, t * t * t])  # x lies on the manifold by construction
f1(x).backward()
print(t.grad)  # same answer whether f1 or f2 is used

When you apply the backwards function of a scalar function like this to the value 1, you get the gradient of the implementation of f. In general, for most implementations, this will not be a tangent to the manifold - i.e. if you take a tiny step in the resulting direction you will move off the manifold. This is not a bug. The backwards function is not an implementation of the "differential map" of the function f, whose output is a tangent, and making it one would be valid but a waste of effort.
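
To make that concrete with f1 and f2 from above (a small sketch; at the point (a, a**3) a tangent direction of the manifold is (1, 3*a**2)):

import torch

a = 2.0
x = torch.tensor([a, a ** 3], requires_grad=True)  # a point on the manifold

f1(x).backward()
g1 = x.grad.clone()
print(g1)  # tensor([80., 0.])

x.grad = None
f2(x).backward()
g2 = x.grad
print(g2)  # tensor([32., 4.])

# The two gradients differ and neither is parallel to the tangent (1., 12.),
# but their components along the tangent agree - which is why t.grad above is
# the same for either implementation.
tangent = torch.tensor([1.0, 12.0])
print(g1 @ tangent, g2 @ tangent)  # tensor(80.), tensor(80.)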

In your example, rotation matrices form a manifold inside tensors of shape (3,3). You cannot directly learn a rotation matrix by gradient descent. Even if you had the differential map, any finite step would take you off the manifold and you would need to make numerical corrections, so your example code couldn't work. When you write so3_log_map(mat), you are applying so3_log_map to a matrix which is not a rotation matrix, where it isn't even supposed to be defined.
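
For instance, a single plain gradient step on a raw 3x3 rotation matrix already leaves the manifold (a sketch; the particular loss doesn't matter here):

import torch
from pytorch3d import transforms

torch.manual_seed(0)
R = transforms.random_rotations(1).clone().requires_grad_(True)
target = transforms.random_quaternions(1)

loss = torch.sum((transforms.matrix_to_quaternion(R) - target) ** 2)
loss.backward()

with torch.no_grad():
    R_step = (R - 0.05 * R.grad)[0]
print(R_step @ R_step.T)  # in general no longer the identity, so R_step is not a rotation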

Euler angles are different: they fill out the whole parameter space rather than being confined to a lower-dimensional manifold, so it is safe to learn them directly by gradient descent. It sounds like you were always trying to learn in Euler angle space, so I don't know what was going wrong, but the problem shouldn't be the form of matrix_to_quaternion.

tomforge commented 3 years ago

I see. Thanks for the detailed explanation! Am I correct to say then that the 6D representation of rotation is suitable for learning via gradient descent as well?

Specific to what I was trying to do, I was indeed trying to learn in Euler angle space. The issue was that I had initialized the Euler angles to 0 and implemented the conversion from Euler angles to quaternions via euler --(euler_angles_to_matrix)--> matrix --(matrix_to_quaternion)--> quat.

Since the matrix only receives gradients on its main diagonal (due to matrix_to_quaternion), and its main diagonal comprises only cosine terms of the underlying Euler angles (due to euler_angles_to_matrix), which become sine terms in the derivative, the Euler angles receive zero gradients when they are at 0 (or n*2pi).
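
A minimal reproduction of this (a sketch; the target quaternion is arbitrary, and the printed gradient reflects the trace/copysign conversion described in this issue - newer versions of matrix_to_quaternion may behave differently):

import torch
from pytorch3d import transforms

euler = torch.zeros(1, 3, requires_grad=True)  # Euler angles initialized to 0
target = transforms.random_quaternions(1)      # arbitrary target for the loss

q = transforms.matrix_to_quaternion(transforms.euler_angles_to_matrix(euler, "XYZ"))
loss = torch.sum((q - target) ** 2)
loss.backward()
# With the trace/copysign conversion, every entry is 0 because the diagonal of
# euler_angles_to_matrix has zero derivative at euler = 0.
print(euler.grad)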

bottler commented 3 years ago

Your case is interesting. Yes, around the identity rotation you don't get gradients.

Yes - I think the 6D representation does exactly what you want: its values aren't stuck on a manifold.
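
For completeness, a minimal sketch of optimizing in the 6D representation (this assumes the rotation_6d_to_matrix / matrix_to_rotation_6d helpers in pytorch3d.transforms, and uses a simple Frobenius loss on matrices just for illustration):

import torch
from pytorch3d import transforms

torch.manual_seed(0)
target = transforms.random_rotations(1)

# Unconstrained 6D parameters, initialized at the identity rotation.
d6 = transforms.matrix_to_rotation_6d(torch.eye(3)[None]).clone().requires_grad_(True)
optim = torch.optim.Adam([d6], lr=0.05)

for i in range(200):
    optim.zero_grad()
    R = transforms.rotation_6d_to_matrix(d6)  # always a valid rotation matrix
    loss = torch.sum((R - target) ** 2)
    loss.backward()
    optim.step()

print(loss.item())  # decreases toward 0 as R converges to the target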

github-actions[bot] commented 3 years ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] commented 3 years ago

This issue was closed because it has been stalled for 5 days with no activity.