princeton-vl / lietorch

BSD 3-Clause "New" or "Revised" License

act does not seem to work with multiple points #3

Closed: hturki closed this issue 3 years ago

hturki commented 3 years ago

Hi,

Thanks for the great project and paper. I'm trying to use the act operator on multiple points, but I can't seem to get it to work. As a toy example, consider:

import torch
from lietorch import SO3

a = torch.FloatTensor([0.4230, 0.5557, 0.3167, 0.6419])
so3 = SO3(a)
b = torch.rand(10, 3)
so3.act(b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data/hturki/miniconda3/envs/test/lib/python3.8/site-packages/lietorch-0.1-py3.8-linux-x86_64.egg/lietorch/groups.py", line 163, in act
    return self.apply_op(Act3, self.data, p)
  File "/data/hturki/miniconda3/envs/test/lib/python3.8/site-packages/lietorch-0.1-py3.8-linux-x86_64.egg/lietorch/groups.py", line 120, in apply_op
    inputs, out_shape = broadcast_inputs(x, y)
  File "/data/hturki/miniconda3/envs/test/lib/python3.8/site-packages/lietorch-0.1-py3.8-linux-x86_64.egg/lietorch/broadcasting.py", line 15, in broadcast_inputs
    check_broadcastable(x, y)
  File "/data/hturki/miniconda3/envs/test/lib/python3.8/site-packages/lietorch-0.1-py3.8-linux-x86_64.egg/lietorch/broadcasting.py", line 5, in check_broadcastable
    assert len(x.shape) == len(y.shape)
AssertionError

Also note that so3.act(b.T) doesn't fail, but it seems to silently return None. Am I calling this operation incorrectly?
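For comparison, here is a plain-PyTorch sketch of what acting on a batch of points should compute. It assumes lietorch stores SO3 elements as unit quaternions in (x, y, z, w) order, and `quat_act` and `cross` are hypothetical helpers, not part of lietorch. Unlike the `check_broadcastable` assertion in the traceback above, it relies on ordinary PyTorch broadcasting, so a single quaternion of shape (4,) rotates points of shape (10, 3) directly.

```python
import torch


def cross(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Cross product over the last dimension; elementwise ops broadcast the leading dims.
    ax, ay, az = a[..., 0], a[..., 1], a[..., 2]
    bx, by, bz = b[..., 0], b[..., 1], b[..., 2]
    return torch.stack([ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx], dim=-1)


def quat_act(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Rotate points p (..., 3) by unit quaternions q (..., 4) in assumed (x, y, z, w) order."""
    v, w = q[..., :3], q[..., 3:]  # vector part (..., 3), scalar part (..., 1)
    t = 2.0 * cross(v, p)
    # p' = p + 2w (v x p) + 2 v x (v x p), written with t = 2 (v x p)
    return p + w * t + cross(v, t)


q = torch.tensor([0.4230, 0.5557, 0.3167, 0.6419])  # quaternion from the example above
q = q / q.norm()  # the formula assumes a unit quaternion
b = torch.rand(10, 3)
out = quat_act(q, b)  # shape (10, 3): one rotation applied to all 10 points
```

Because a rotation preserves lengths, `out.norm(dim=-1)` should match `b.norm(dim=-1)` up to float error, which is a quick sanity check for any act implementation.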

hturki commented 3 years ago

Also, more broadly speaking, this is the sort of function I was hoping to implement with lietorch:

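import torch

# directions: (n, H*W, 3) camera-frame ray directions; c2w: (n, 3, 4) camera-to-world matrices
def get_rays(directions, c2w):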
    # Rotate ray directions from camera coordinate to the world coordinate
    rays_d = directions @ c2w[:, :, :3].transpose(1, 2)  # (n, H*W, 3)
    # The origin of all rays is the camera origin in world coordinate
    rays_o = c2w[:, :, 3].unsqueeze(1).expand(rays_d.shape)  # (n, H*W, 3)
    return torch.cat((rays_o, rays_d), -1)

Here, directions is an (N, M, 3) tensor and c2w would be an SE3 transform. Note that the rotation and translation components of the transform are treated separately. So ideally I'd be able to:

The easiest solution I could think of was to call the matrix() function on the SE3 transform, but it doesn't currently seem to be differentiable.
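One way to avoid `matrix()` is to split the pose into its translation and rotation and apply them separately, all with differentiable tensor ops. The sketch below assumes the pose is stored as a 7-vector `[tx, ty, tz, qx, qy, qz, qw]` (the layout lietorch is understood to use for SE3 data; treat it as an assumption), and `get_rays_from_pose` and `cross` are hypothetical helpers written in plain PyTorch rather than lietorch calls.

```python
import torch


def cross(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Cross product over the last dimension, broadcasting the leading dims.
    ax, ay, az = a[..., 0], a[..., 1], a[..., 2]
    bx, by, bz = b[..., 0], b[..., 1], b[..., 2]
    return torch.stack([ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx], dim=-1)


def get_rays_from_pose(directions: torch.Tensor, pose: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper.

    directions: (n, M, 3) ray directions in camera coordinates.
    pose: (n, 7) camera-to-world poses laid out as [tx, ty, tz, qx, qy, qz, qw] (assumed).
    Returns (n, M, 6) rays: world-space origin followed by world-space direction.
    """
    t = pose[:, None, :3]                    # (n, 1, 3) translation = camera origin
    q = pose[:, None, 3:]                    # (n, 1, 4) rotation quaternion
    q = q / q.norm(dim=-1, keepdim=True)     # keep it a unit quaternion
    v, w = q[..., :3], q[..., 3:]
    # Only the rotation acts on directions; the translation only supplies the origin.
    u = 2.0 * cross(v, directions)
    rays_d = directions + w * u + cross(v, u)  # (n, M, 3)
    rays_o = t.expand_as(rays_d)               # (n, M, 3)
    return torch.cat((rays_o, rays_d), dim=-1)


torch.manual_seed(0)
directions = torch.randn(2, 5, 3)
pose = torch.randn(2, 7, requires_grad=True)
rays = get_rays_from_pose(directions, pose)
rays.sum().backward()  # gradients flow to both the translation and the quaternion entries
```

This mirrors the structure of `get_rays` above: rotated directions plus a broadcast origin, concatenated along the last dimension.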

zachteed commented 3 years ago

Hi, that function is possible to implement. In your code, you just need to add a dimension to so3:

import torch
from lietorch import SO3

a = torch.FloatTensor([0.4230, 0.5557, 0.3167, 0.6419])
b = torch.rand(10, 3)
so3 = SO3(a)
so3[None].act(b)

Or, if you were acting on an even higher-dimensional point cloud:

a = torch.FloatTensor([0.4230, 0.5557, 0.3167, 0.6419])
b = torch.rand(10,10,10, 3)
so3 = SO3(a)
so3[None,None,None].act(b)

Right now, broadcasting assumes that the group element and the points have the same number of dimensions, so you have to manually expand the dimensions you want to broadcast over.
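The rule can be stated in terms of batch shapes: `SO3(a)` with `a` of shape (4,) has batch shape (), while `b` of shape (10, 3) has batch shape (10,), so the element needs one leading singleton dimension. The pure-Python sketch below is a hypothetical re-creation of that rule, not lietorch's code; the equal-length check matches the assertion in the traceback, and the per-dimension "equal or 1" check is an assumption.

```python
from typing import Tuple


def lie_broadcast_shape(group_batch: Tuple[int, ...], point_batch: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical: broadcast the batch shape of a group element against that of the points."""
    # Both must have the same number of batch dimensions (no implicit leading dims).
    if len(group_batch) != len(point_batch):
        raise ValueError(f"ndim mismatch: {group_batch} vs {point_batch}")
    out = []
    for g, p in zip(group_batch, point_batch):
        # Assumed per-dimension rule: sizes must match, or one of them must be 1.
        if g != p and g != 1 and p != 1:
            raise ValueError(f"incompatible sizes {g} and {p}")
        out.append(p if g == 1 else g)
    return tuple(out)


def leading_nones(group_ndim: int, point_ndim: int) -> Tuple[None, ...]:
    """Index tuple that prepends singleton dims, e.g. (None, None, None) for so3[None, None, None]."""
    return (None,) * (point_ndim - group_ndim)


# so3[None].act(b) with b of shape (10, 3): batch shapes (1,) and (10,) broadcast to (10,).
shape = lie_broadcast_shape((1,), (10,))
```

With lietorch, `so3[leading_nones(0, b.dim() - 1)]` would be equivalent to writing the `None` indices by hand as in the reply, since a tuple of `None`s is what `so3[None, None, None]` passes to indexing.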

Leerw commented 2 years ago

Hi, I want to implement a toy example for optimizing a rotation. Here is my code, but it does not converge. Is there anything wrong with it?

import torch

from pytorch3d.transforms import random_rotation
from lietorch import SO3

x = torch.randn(100, 3, requires_grad=False).cuda()
rotation = random_rotation().cuda().requires_grad_(False)
print(rotation)
y = x @ rotation

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.phi = torch.nn.Parameter(torch.randn(1, 3, requires_grad=True).cuda())

    def forward(self, x):
        so3 = SO3(self.phi)
        return so3.act(x)

model = Model().cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

for i in range(1000):
    y_hat = model(x)
    loss = torch.nn.functional.mse_loss(y_hat, y.detach())
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
    print(loss.item())

print(SO3(model.phi).matrix())
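A few things in this code look likely to cause trouble. `SO3(self.phi)` wraps a 3-vector as raw group data, whereas the first example in this thread constructs `SO3` from a 4-element quaternion; mapping a tangent 3-vector to a rotation is what `SO3.exp(self.phi)` (as shown in the lietorch README) is for. With `lr=0.0001`, Adam moves each coordinate of `phi` by roughly `lr` per step, so 1000 steps change it by only about 0.1, probably too little to reach a random rotation. For row vectors, `x @ rotation` applies the transpose of `rotation`, so a converged model would reproduce `rotation.T`, not `rotation`. Finally, `retain_graph=True` is unnecessary because the graph is rebuilt each iteration. The sketch below applies these fixes but swaps lietorch for a plain-PyTorch Rodrigues exponential map (hypothetical helpers `hat` and `so3_exp`) and a fixed target instead of pytorch3d's `random_rotation`, so it runs on CPU without either library.

```python
import torch


def hat(phi: torch.Tensor) -> torch.Tensor:
    """Skew-symmetric matrix of a 3-vector, so that hat(a) @ b == cross(a, b)."""
    zero = phi.new_zeros(())
    x, y, z = phi[0], phi[1], phi[2]
    return torch.stack([
        torch.stack([zero, -z, y]),
        torch.stack([z, zero, -x]),
        torch.stack([-y, x, zero]),
    ])


def so3_exp(phi: torch.Tensor) -> torch.Tensor:
    """Rodrigues' formula: axis-angle 3-vector -> 3x3 rotation matrix."""
    theta = torch.sqrt((phi * phi).sum() + 1e-24)  # tiny offset keeps gradients finite at phi = 0
    K = hat(phi)
    A = torch.sin(theta) / theta
    B = (1 - torch.cos(theta)) / theta ** 2
    return torch.eye(3, dtype=phi.dtype, device=phi.device) + A * K + B * (K @ K)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Tangent-space (axis-angle) parameter, initialised at the identity rotation.
        self.phi = torch.nn.Parameter(torch.zeros(3, dtype=torch.float64))

    def forward(self, x):
        R = so3_exp(self.phi)  # in lietorch, the analogue would be SO3.exp(self.phi)
        return x @ R.T  # rotate each row vector: y_i = R x_i


torch.manual_seed(0)
x = torch.randn(100, 3, dtype=torch.float64)
R_true = so3_exp(torch.tensor([0.3, -1.2, 0.8], dtype=torch.float64))  # fixed target rotation
y = x @ R_true.T  # same convention as the model, so the fit should recover R_true itself

model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)  # 1e-4 barely moves phi in 1000 steps
for i in range(1000):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()  # no retain_graph: a fresh graph is built every iteration
    optimizer.step()

R_fit = so3_exp(model.phi.detach())
print(loss.item())
print(R_fit)
print(R_true)
```

Optimizing the 3-vector `phi` and mapping it through the exponential keeps every iterate a valid rotation, which is the same idea as optimizing a tangent vector and calling `SO3.exp` in lietorch.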