facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

get_full_projection_transform not functioning properly for batch sizes > 2 #1843

Closed simonvw95 closed 2 months ago

simonvw95 commented 2 months ago

🐛 Bugs / Unexpected behaviors

PyTorch3D version 0.7.6, Python version 3.10.4, torch version 2.1.2+cu118

I have been using pytorch3d to compute perspective projection matrices. Now that I'm working with batched data, I have encountered an issue with this version of pytorch3d: when using a PerspectiveCameras object with a given focal length, a batched rotation matrix R, and a batched translation T, the get_full_projection_transform function performs as intended and returns a Transform3d object. However, when I call get_matrix() on that result to acquire the batched transformation matrices, I get the following error:

  File "C:\aaaa\venv\lib\site-packages\pytorch3d\transforms\transform3d.py", line 248, in get_matrix
    composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
  File "C:\aaaa\venv\lib\site-packages\pytorch3d\transforms\transform3d.py", line 876, in _broadcast_bmm
    raise ValueError(msg % (a.shape, b.shape))
ValueError: Expected batch dim for bmm to be equal or 1; got torch.Size([10, 4, 4]), torch.Size([2, 4, 4])

Upon further inspection, the get_matrix() function itself appears to be working correctly: each individual line works, apart from one. The error lies in one of the transforms composed by get_full_projection_transform:

world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
view_to_proj_transform = self.get_projection_transform(**kwargs)

Regardless of the number of samples in the batch, self.get_projection_transform(**kwargs) always returns a transform whose matrix holds only 2 samples:

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.]],
        [[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.]]])
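
To make the mismatch concrete, the two composed transforms can be compared directly; a minimal sketch, using the same setup as the reproduction below:

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras

camera = PerspectiveCameras(focal_length=torch.tensor([1, 1]).float())
R, T = look_at_view_transform(1, torch.randn(10), torch.randn(10))
# the world-to-view transform follows the batch size of R and T ...
print(camera.get_world_to_view_transform(R=R, T=T).get_matrix().shape)  # torch.Size([10, 4, 4])
# ... but the projection transform follows the camera's own batch size
print(camera.get_projection_transform().get_matrix().shape)  # torch.Size([2, 4, 4])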

An easy fix is to skip the get_matrix() call on the result of get_full_projection_transform, apply the steps of get_matrix() manually, and replace the matrix from view_to_proj_transform with one of the proper batch size (see the workaround snippet further below).

Instructions To Reproduce the Issue:

When setting the batch size > 2, for instance 10 as shown below, we get the error shown above.

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras

batch_size = 10
angles = torch.randn((batch_size, 2))

elevation = angles[:, 0]
azimuth = angles[:, 1]
camera = PerspectiveCameras(focal_length=torch.tensor([1, 1]).float())

# get batched transformation matrices based on batched elevations and azimuths
R, T = look_at_view_transform(1, elevation.float(), azimuth.float())
transform_matrix = camera.get_full_projection_transform(R=R.float(), T=T.float()).get_matrix()

A simple workaround is the following:

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
from pytorch3d.transforms.transform3d import _broadcast_bmm

batch_size = 10
angles = torch.randn((batch_size, 2))

elevation = angles[:, 0]
azimuth = angles[:, 1]
camera = PerspectiveCameras(focal_length=torch.tensor([1, 1]).float())

# get batched transformation matrices based on batched elevations and azimuths
R, T = look_at_view_transform(1, elevation.float(), azimuth.float())
res_tf3d_obj = camera.get_full_projection_transform(R=R.float(), T=T.float())

# replicate the steps of get_matrix(), composing the stored transforms manually
composed_matrix = res_tf3d_obj._matrix.clone()
n_in_batch = R.shape[0]
other_matrix = res_tf3d_obj._transforms[0].get_matrix()
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
# replace the projection matrix (stuck at batch size 2) with one repeated to the proper size
other_matrix = torch.tensor([[1., 0., 0., 0.],
                             [0., 1., 0., 0.],
                             [0., 0., 0., 1.],
                             [0., 0., 1., 0.]]).float()
other_matrix = other_matrix.repeat(n_in_batch, 1, 1)
transform_matrix = _broadcast_bmm(composed_matrix, other_matrix)
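
With this, transform_matrix should contain one 4x4 matrix per sample; a quick check, assuming the snippet above with batch_size = 10:

print(transform_matrix.shape)  # torch.Size([10, 4, 4])
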
bottler commented 2 months ago

You need a single, consistent batch size for the camera, R, and T. The error could be friendlier, but it is correct.

The line

camera = PerspectiveCameras(focal_length=torch.tensor([1, 1]).float())

in fact creates a batch of two cameras; if you want to use a tensor as the focal length, it's recommended to pass in a 2D tensor. I think it would be easier to do

camera = PerspectiveCameras(focal_length=torch.ones(batch_size, 1, dtype=torch.float32))
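
A quick way to see the difference (assuming the cameras object reports its batch size via len()):

import torch
from pytorch3d.renderer import PerspectiveCameras

# a 1D tensor of length 2 is read as two cameras, each with a scalar focal length
print(len(PerspectiveCameras(focal_length=torch.tensor([1, 1]).float())))  # 2

# a 2D (N, 1) tensor gives one camera per row, matching R and T with batch size 10
print(len(PerspectiveCameras(focal_length=torch.ones(10, 1, dtype=torch.float32))))  # 10
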
simonvw95 commented 2 months ago

Ah, I completely overlooked the fact that the initialization of the focal length could be the problem! Thank you. If the batch size stays constant, then this is the best way to do it, yes. However, if the dataset's number of samples is not divisible by the batch size, initializing the camera once at the start is insufficient: e.g., a dataset of 100 samples with a batch size of 24 yields four batches of size 24 and one final batch of size 4.

I suppose one can create a new camera object sized to each batch for these edge cases, as in the sketch below. Regardless, thank you for pointing out the actual problem!
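
For completeness, a minimal sketch of that per-batch approach (the 100-sample dataset and batch size of 24 just mirror the example above):

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras

# hypothetical dataset of 100 (elevation, azimuth) pairs, iterated in batches of 24;
# the final batch only has 4 samples
angles = torch.randn(100, 2)
batch_size = 24
for start in range(0, len(angles), batch_size):
    batch = angles[start:start + batch_size]
    R, T = look_at_view_transform(1, batch[:, 0].float(), batch[:, 1].float())
    # size the camera to the actual number of samples in this batch,
    # so the smaller final batch works too
    camera = PerspectiveCameras(focal_length=torch.ones(len(batch), 1, dtype=torch.float32))
    transform_matrix = camera.get_full_projection_transform(R=R, T=T).get_matrix()
    assert transform_matrix.shape == (len(batch), 4, 4)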