facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

FoVPerspectiveCameras handles user-supplied K matrix in an unexpected way #1255

Closed mikeroberts3000 closed 2 years ago

mikeroberts3000 commented 2 years ago

Hello there, I'm experiencing a subtle issue with the FoVPerspectiveCameras class that only occurs when I supply my own K matrix, and only for non-square images.

I'm following the render_textured_meshes tutorial example exactly, but I'm experimenting with the following two modifications. First, I am attempting to change the RasterizationSettings object to render non-square images. Second, I am attempting to supply my own perspective projection matrix to the FoVPerspectiveCameras class. I expect that both modifications should be possible at the same time, but I'm getting what I think is incorrect rendering output in some cases.

In my experiments, I compute my projection matrix exactly the same as the glFrustum OpenGL function, as documented here [1,2]. I have used the exact same code to pixel-perfectly match the rendering output from V-Ray (a commercial rendering engine) throughout the Hypersim project [3], so I consider my code to be a correct reimplementation of glFrustum. I take care to account for the fact that PyTorch3D's NDC coordinate conventions differ from OpenGL's, as documented here [4].

I found that for square images, I get identical rendering output when I supply my own K matrix and when I let FoVPerspectiveCameras compute its own K matrix. So far, so good.

For non-square images, I get the correct rendering output when I let FoVPerspectiveCameras compute its own K matrix. I can tell it's correct because the rendering output is pixel-perfect identical to the square image case, but with more white pixels on the left and right borders of the image. This is exactly what you would expect if you keep the vertical field-of-view fixed, keep the image height fixed, and increase the image width, which is what I'm doing in my experiment. Again, so far, so good.
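The geometry behind this expectation is that, with the vertical field of view and image height held fixed, widening the image only widens the horizontal field of view. Here is a minimal sketch of that relation. It assumes square pixels and a symmetric frustum, and the helper name `horizontal_fov` is hypothetical:

```python
import math


def horizontal_fov(fov_y, width_pixels, height_pixels):
    """Horizontal FoV (radians) implied by a fixed vertical FoV and the image aspect ratio.

    Hypothetical helper; assumes square pixels and a symmetric frustum.
    """
    return 2.0 * math.atan(math.tan(fov_y / 2.0) * width_pixels / height_pixels)


fov_y = math.pi / 3.0  # 60 degrees, as in the snippets in this thread
print(math.degrees(horizontal_fov(fov_y, 384, 384)))  # square image: equals fov_y
print(math.degrees(horizontal_fov(fov_y, 512, 384)))  # wider image: wider horizontal FoV
```

For the 384x512 case this comes out a little above 75 degrees, so the extra columns only reveal more of the scene on the left and right.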

But I get incorrect rendering output when I have a non-square image and I supply my own K matrix. This is the case even though I account for the image's aspect ratio correctly (i.e., my implementation matches glFrustum) in the code that computes the projection matrix.

Here is the rendered output when [height, width] == [384, 384]. It is the same for an automatically computed K and a user-supplied K:

[Image: ref_384_384]

Here is the rendered output when [height, width] == [384, 512], and the field-of-view is fixed, for an automatically computed K. In this case, I would expect the cow to be rendered pixel-for-pixel identically to the image above, but with more white pixels on the left and right of the image, and that is exactly what we get:

[Image: ref_384_512]

But here is the rendered output when [height, width] == [384, 512], and the field-of-view is fixed, when I supply my own K. In this case, I would expect an identical image to the one immediately above, but that is not what I get:

[Image: user_K_384_512]

Is it expected that an OpenGL projection matrix (e.g., as computed by glFrustum) can be supplied directly to the FoVPerspectiveCameras constructor as a user-supplied K matrix, assuming the user takes care to flip the sign of the (3,2) entry to account for PyTorch3D's different NDC conventions? If so, I believe the behavior I'm experiencing is a bug in PyTorch3D: I'm providing a K matrix that agrees with the glFrustum documentation and accounts for the different NDC conventions, but I'm not getting the correct rendered output.

I'll post code snippets in a follow-up post below.

[1] https://docs.microsoft.com/en-us/windows/win32/opengl/glfrustum
[2] http://nehe.gamedev.net/article/replacement_for_gluperspective/21002/
[3] https://github.com/apple/ml-hypersim
[4] https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/renderer/cameras.py#L596

mikeroberts3000 commented 2 years ago

Here is my code for letting FoVPerspectiveCameras compute its own K matrix:

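import numpy as np
import matplotlib.pyplot as plt
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
    look_at_view_transform,
)

# `device` and the textured cow `mesh` are assumed to be set up as in the render_textured_meshes tutorial.

# Initialize a camera.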

# look at the tip of the right brown bump on the cow's head
R, T = look_at_view_transform(eye=[[0.0, 5.0, 5.0]], up=[[0.0, 1.0, 0.0]], at=[[0.182289, 0.937539, -0.304334]])
print(R)
print(T)

# set each dimension to be half the size of a Hypersim image
width_pixels  = 512
height_pixels = 384
fov_y         = np.pi / 3.0

# aspect_ratio here is the aspect ratio of an individual pixel so we set to 1.0 regardless of width and height
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fov_y, degrees=False, aspect_ratio=1.0)

raster_settings = RasterizationSettings(
    image_size=[height_pixels, width_pixels],
    blur_radius=0.0, 
    faces_per_pixel=1, 
)

# Place a point light in front of the object. As mentioned above, the front of the cow is facing the 
# -z direction. 
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Create a Phong renderer by composing a rasterizer and a shader. The textured Phong shader will 
# interpolate the texture uv coordinates for each vertex, sample from a texture image and 
# apply the Phong lighting model
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights
    )
)

images = renderer(mesh)
plt.imsave("ref_384_512.png", images[0, ..., :3].cpu().numpy())

plt.figure(figsize=(10, 10))
plt.imshow(images[0, ..., :3].cpu().numpy())
plt.axis("off");

Here is my code for supplying my own K matrix:

# Initialize a camera.

# look at the tip of the right brown bump on the cow's head
R, T = look_at_view_transform(eye=[[0.0, 5.0, 5.0]], up=[[0.0, 1.0, 0.0]], at=[[0.182289, 0.937539, -0.304334]])
print(R)
print(T)

# set each dimension to be half the size of a Hypersim image
width_pixels  = 512
height_pixels = 384
fov_y         = np.pi / 3.0

near = 1.0
far  = 100.0

# construct projection matrix
f_h    = np.tan(fov_y/2.0)*near
f_w    = f_h*width_pixels/height_pixels
left   = -f_w
right  = f_w
bottom = -f_h
top    = f_h

M_proj      = np.matrix(np.zeros((4,4)))
M_proj[0,0] = (2.0*near)/(right - left)
M_proj[1,1] = (2.0*near)/(top - bottom)
M_proj[0,2] = (right + left)/(right - left)
M_proj[1,2] = (top + bottom)/(top - bottom)
M_proj[2,2] = -(far + near)/(far - near)
M_proj[3,2] = -1.0
M_proj[2,3] = -(2.0*far*near)/(far - near)

M_pytorch_from_opengl = np.matrix(np.identity(4))
M_pytorch_from_opengl[3,3] = -1

K = (M_pytorch_from_opengl*M_proj).A

print(M_pytorch_from_opengl)
print(M_proj)
print(K)

# aspect_ratio here is the aspect ratio of an individual pixel so we set to 1.0 regardless of width and height
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fov_y, degrees=False, aspect_ratio=1.0, K=[K])

# also doesn't work
# cameras = FoVPerspectiveCameras(device=device, R=R, T=T, K=[K])

raster_settings = RasterizationSettings(
    image_size=[height_pixels, width_pixels],
    blur_radius=0.0, 
    faces_per_pixel=1,
)

# Place a point light in front of the object. As mentioned above, the front of the cow is facing the 
# -z direction. 
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Create a Phong renderer by composing a rasterizer and a shader. The textured Phong shader will 
# interpolate the texture uv coordinates for each vertex, sample from a texture image and 
# apply the Phong lighting model
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights
    )
)

images = renderer(mesh)
plt.imsave("user_K_384_512.png", images[0, ..., :3].cpu().numpy())

plt.figure(figsize=(10, 10))
plt.imshow(images[0, ..., :3].cpu().numpy())
plt.axis("off");

Notice that I'm accounting for the image aspect ratio exactly as described in the documentation for glFrustum.
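Written out, the matrix built above is the standard glFrustum matrix. Substituting the symmetric frustum used in the snippet ($t = n\tan(\theta_y/2)$, $r = t\,W/H$, $l = -r$, $b = -t$) shows that the image aspect ratio $W/H$ enters only through the $(0,0)$ entry:

```latex
M_{\text{proj}} =
\begin{pmatrix}
\frac{2n}{r-l} & 0 & \frac{r+l}{r-l} & 0 \\
0 & \frac{2n}{t-b} & \frac{t+b}{t-b} & 0 \\
0 & 0 & -\frac{f+n}{f-n} & -\frac{2fn}{f-n} \\
0 & 0 & -1 & 0
\end{pmatrix},
\qquad
M_{00} = \frac{n}{r} = \frac{H}{W\,\tan(\theta_y/2)},
\quad
M_{11} = \frac{n}{t} = \frac{1}{\tan(\theta_y/2)},
\quad
M_{02} = M_{12} = 0 .
```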

mikeroberts3000 commented 2 years ago

Update: I was able to get the correct rendering output for non-square images. But I needed to set the image width and height to be equal when computing my own K matrix, and then set them to their true values when instantiating the RasterizationSettings object.

This is an important difference between the K matrix in FoVPerspectiveCameras and the standard OpenGL projection matrix. Is this difference intended?

Here is the working code:

# Initialize a camera.

# look at the tip of the right brown bump on the cow's head
R, T = look_at_view_transform(eye=[[0.0, 5.0, 5.0]], up=[[0.0, 1.0, 0.0]], at=[[0.182289, 0.937539, -0.304334]])
print(R)
print(T)

# set each dimension to be half the height of a Hypersim image, even though this is not the intended image size
width_pixels  = 384
height_pixels = 384
fov_y         = np.pi / 3.0

near = 1.0
far  = 100.0

# construct projection matrix
f_h    = np.tan(fov_y/2.0)*near
f_w    = f_h*width_pixels/height_pixels
left   = -f_w
right  = f_w
bottom = -f_h
top    = f_h

M_proj      = np.matrix(np.zeros((4,4)))
M_proj[0,0] = (2.0*near)/(right - left)
M_proj[1,1] = (2.0*near)/(top - bottom)
M_proj[0,2] = (right + left)/(right - left)
M_proj[1,2] = (top + bottom)/(top - bottom)
M_proj[2,2] = -(far + near)/(far - near)
M_proj[3,2] = -1.0
M_proj[2,3] = -(2.0*far*near)/(far - near)

M_pytorch_from_opengl = np.matrix(np.identity(4))
M_pytorch_from_opengl[3,3] = -1

K = (M_pytorch_from_opengl*M_proj).A

print(M_pytorch_from_opengl)
print(M_proj)
print(K)

cameras = FoVPerspectiveCameras(device=device, R=R, T=T, K=[K])

# only set the actual width and height after K has been computed
width_pixels  = 512
height_pixels = 384

raster_settings = RasterizationSettings(
    image_size=[height_pixels, width_pixels],
    blur_radius=0.0, 
    faces_per_pixel=1,
)
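You can check numerically why this workaround works. Building the frustum as if the image were square gives the same matrix as building it for the true 384x512 image and then scaling the x row by W/H. The PyTorch3D sign flip, diag(1, 1, 1, -1), is also diagonal and commutes with that scaling, so the same holds after conversion to K. The sketch below uses plain NumPy arrays instead of np.matrix; the helper name `gl_frustum_symmetric` is hypothetical:

```python
import numpy as np


def gl_frustum_symmetric(fov_y, width_pixels, height_pixels, near=1.0, far=100.0):
    # Same construction as in the snippets above (hypothetical helper name).
    f_h = np.tan(fov_y / 2.0) * near
    f_w = f_h * width_pixels / height_pixels
    left, right, bottom, top = -f_w, f_w, -f_h, f_h
    M = np.zeros((4, 4))
    M[0, 0] = (2.0 * near) / (right - left)
    M[1, 1] = (2.0 * near) / (top - bottom)
    M[0, 2] = (right + left) / (right - left)
    M[1, 2] = (top + bottom) / (top - bottom)
    M[2, 2] = -(far + near) / (far - near)
    M[3, 2] = -1.0
    M[2, 3] = -(2.0 * far * near) / (far - near)
    return M


fov_y = np.pi / 3.0
M_square = gl_frustum_symmetric(fov_y, 384, 384)  # workaround: built as if the image were square
M_wide = gl_frustum_symmetric(fov_y, 512, 384)    # true 384x512 frustum

# Scaling the x row of the true frustum by W/H reproduces the square-image matrix.
S = np.diag([512 / 384, 1.0, 1.0, 1.0])
print(np.allclose(S @ M_wide, M_square))
```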
gkioxari commented 2 years ago

@mikeroberts3000 Thank you so much for your detailed explanation and code snippets! This is a fantastic example of how to report issues properly!

I think I understand the issue and the source of the confusion. We made some deliberate design choices for non-square rendering in PyTorch3D that can be confusing in the context of OpenGL. Let me briefly explain.

Throughout this post, assume we are talking about NDC coordinates.

How to define cameras for non-square renders?

To recap: for square images, K should project the part of the scene you wish to render to [-1, 1]x[-1, 1]. This is compatible with most camera conventions out there, including OpenGL's, which is why you see no problem with your code for square images!

For non-square images, PyTorch3D's conventions expect K to project the part of the scene to be rendered to [-u, u]x[-1, 1] (for images wider than they are tall) or [-1, 1]x[-u, u] (for images taller than they are wide), where u = (longer image side) / (shorter image side) > 1. If your K projects the scene to [-1, 1]x[-1, 1] instead, the rendered image will be wrong!
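This rule can be written as a small helper that returns the half-extents of the NDC rectangle for a given image size. It is a sketch of the convention as stated above, not a PyTorch3D function, and the name `pytorch3d_ndc_extent` is hypothetical:

```python
def pytorch3d_ndc_extent(height_pixels, width_pixels):
    """Return (x_max, y_max) so that NDC spans [-x_max, x_max] x [-y_max, y_max].

    Hypothetical helper: the shorter image side spans [-1, 1] and the longer
    side spans [-u, u], with u = longer side / shorter side.
    """
    if width_pixels >= height_pixels:
        return width_pixels / height_pixels, 1.0
    return 1.0, height_pixels / width_pixels


print(pytorch3d_ndc_extent(384, 384))  # square image
print(pytorch3d_ndc_extent(384, 512))  # wider than tall: x range is stretched
print(pytorch3d_ndc_extent(512, 384))  # taller than wide: y range is stretched
```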

Let's look at your case. Your OpenGL-style K is constructed to project to [-1, 1]x[-1, 1] for both square and non-square images. But for (image_height, image_width) = (384, 512), our renderer will render in [-1.33, 1.33]x[-1, 1], as described above. That's the core of the issue, and it is also why you see the cow's butt (!) squashed along the x-axis. To fix this, for non-square images your K should project points to the PyTorch3D space, i.e. [-1.33, 1.33]x[-1, 1]. This also explains why defining K as if the image were square, using the shorter image side, gave you the right result: that K projects to [-1.33, 1.33]x[-1, 1].
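A few lines of arithmetic confirm the squashing factor. For a 384x512 image, a point on the right edge of the intended horizontal field of view lands at x = 1 under the glFrustum K built with the true aspect ratio, but PyTorch3D expects the right edge at x = 4/3. The sketch below considers only the x row and ignores view-space sign conventions, since only magnitudes matter here. The helper name `k00_symmetric_frustum` is hypothetical:

```python
import numpy as np

fov_y = np.pi / 3.0
height_pixels, width_pixels = 384, 512


def k00_symmetric_frustum(fov_y, width_pixels, height_pixels, near=1.0):
    # (0, 0) entry of a glFrustum matrix with right = -left = tan(fov_y/2) * near * W/H.
    right = np.tan(fov_y / 2.0) * near * width_pixels / height_pixels
    return (2.0 * near) / (2.0 * right)


# A point at depth z = 1 on the right edge of the intended horizontal field of view.
x_view, z_view = np.tan(fov_y / 2.0) * width_pixels / height_pixels, 1.0

# K built with the true aspect ratio vs. K built as if the image were square.
x_ndc_true_aspect = k00_symmetric_frustum(fov_y, width_pixels, height_pixels) * x_view / z_view
x_ndc_square = k00_symmetric_frustum(fov_y, height_pixels, height_pixels) * x_view / z_view

print(x_ndc_true_aspect)  # edge of OpenGL's [-1, 1] range
print(x_ndc_square)       # edge of PyTorch3D's [-4/3, 4/3] range for 384x512
```

Because the renderer maps [-4/3, 4/3] to the image width, content projected with the first K occupies only 3/4 of the expected horizontal extent.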

Why on earth did you make that decision?

Non-square rendering is not very common in differentiable rendering. We chose a non-square NDC space in this case because we found it the most natural. That said, I do agree that it causes confusion when you are used to OpenGL conventions! Sorry :(

mikeroberts3000 commented 2 years ago

Hi @gkioxari, thanks for that detailed write-up. This makes total sense.

The good news is that all of these NDC issues (sign-flip, non-square images) can be handled cleanly by left-multiplying an appropriate M_pytorch_from_opengl matrix. In case anyone else needs to do something similar, here is a complete working code snippet for constructing an OpenGL projection matrix, and then constructing an appropriate M_pytorch_from_opengl matrix to perform the conversion to match the PyTorch3D conventions exactly.

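import numpy as np
import matplotlib.pyplot as plt
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
    look_at_view_transform,
)

# `device` and the textured cow `mesh` are assumed to be set up as in the render_textured_meshes tutorial.

# Initialize a camera.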

# look at the tip of the right brown bump on the cow's head
R, T = look_at_view_transform(eye=[[0.0, 5.0, 5.0]], up=[[0.0, 1.0, 0.0]], at=[[0.182289, 0.937539, -0.304334]])
print(R)
print(T)

# set width and height to be half the size of a Hypersim image
width_pixels  = 512
height_pixels = 384
fov_y         = np.pi / 3.0

near = 1.0
far  = 100.0

# construct an OpenGL projection matrix the same as how gluPerspective or glFrustum would construct it
f_h    = np.tan(fov_y/2.0)*near
f_w    = f_h*width_pixels/height_pixels
left   = -f_w
right  = f_w
bottom = -f_h
top    = f_h

M_proj      = np.matrix(np.zeros((4,4)))
M_proj[0,0] = (2.0*near)/(right - left)
M_proj[1,1] = (2.0*near)/(top - bottom)
M_proj[0,2] = (right + left)/(right - left)
M_proj[1,2] = (top + bottom)/(top - bottom)
M_proj[2,2] = -(far + near)/(far - near)
M_proj[3,2] = -1.0
M_proj[2,3] = -(2.0*far*near)/(far - near)

# construct a matrix to convert to PyTorch3D conventions from OpenGL conventions
M_pytorch_from_opengl = np.matrix(np.identity(4))

# flip the sign of the bottom row of M_proj to account for the fact that PyTorch3D's NDC space has a different handedness from OpenGL's
M_pytorch_from_opengl[3,3] = -1

# scale the NDC points produced by M_proj to account for PyTorch3D's non-square NDC space
if width_pixels > height_pixels:
    M_pytorch_from_opengl[0,0] = width_pixels/height_pixels
if height_pixels > width_pixels:
    M_pytorch_from_opengl[1,1] = height_pixels/width_pixels

# left-multiply the OpenGL projection matrix by M_pytorch_from_opengl to construct an appropriate K matrix
K = (M_pytorch_from_opengl*M_proj).A

print(M_pytorch_from_opengl)
print(M_proj)
print(K)

# K can be passed into the FoVPerspectiveCameras constructor; its x and y rows match the automatically computed K (the depth row keeps OpenGL's depth mapping)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, K=[K])

raster_settings = RasterizationSettings(
    image_size=[height_pixels, width_pixels],
    blur_radius=0.0, 
    faces_per_pixel=1,
)

# Place a point light in front of the object. As mentioned above, the front of the cow is facing the 
# -z direction. 
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

# Create a Phong renderer by composing a rasterizer and a shader. The textured Phong shader will 
# interpolate the texture uv coordinates for each vertex, sample from a texture image and 
# apply the Phong lighting model
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights
    )
)

images = renderer(mesh)
plt.imsave("user_K_384_512.png", images[0, ..., :3].cpu().numpy())

plt.figure(figsize=(10, 10))
plt.imshow(images[0, ..., :3].cpu().numpy())
plt.axis("off");

Here is the resulting rendering output, which pixel-perfectly matches the results for an automatically computed K:

[Image: user_K_384_512]
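For reuse, the conversion can be wrapped in a function. The sketch below mirrors the steps of the snippet above with plain NumPy arrays instead of np.matrix. Both helper names are hypothetical. It checks that, for a 384x512 image, the x and y focal entries of the resulting K both equal 1/tan(fov_y/2), the value a square image would get. Only the x/y entries and the (3,2) sign are checked; the depth row is not examined:

```python
import numpy as np


def gl_frustum_symmetric(fov_y, width_pixels, height_pixels, near=1.0, far=100.0):
    # Symmetric glFrustum matrix as built in the snippet above; the off-centre
    # terms (0, 2) and (1, 2) are zero for a symmetric frustum and are left out.
    top = np.tan(fov_y / 2.0) * near
    right = top * width_pixels / height_pixels
    M = np.zeros((4, 4))
    M[0, 0] = near / right  # 2n / (r - l) with l = -r
    M[1, 1] = near / top    # 2n / (t - b) with b = -t
    M[2, 2] = -(far + near) / (far - near)
    M[2, 3] = -(2.0 * far * near) / (far - near)
    M[3, 2] = -1.0
    return M


def opengl_to_pytorch3d_K(M_proj, height_pixels, width_pixels):
    # Left-multiply by the M_pytorch_from_opengl matrix described above.
    M = np.identity(4)
    M[3, 3] = -1.0  # negate the bottom row (NDC handedness)
    if width_pixels > height_pixels:
        M[0, 0] = width_pixels / height_pixels  # x spans [-W/H, W/H]
    elif height_pixels > width_pixels:
        M[1, 1] = height_pixels / width_pixels  # y spans [-H/W, H/W]
    return M @ np.asarray(M_proj, dtype=float)


fov_y = np.pi / 3.0
K = opengl_to_pytorch3d_K(gl_frustum_symmetric(fov_y, 512, 384), 384, 512)
print(K)
print(K[0, 0], K[1, 1], 1.0 / np.tan(fov_y / 2.0))  # the x and y focal entries agree
```

The same function handles portrait images: there the y row is scaled instead, and the x and y focal entries again come out equal.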

luffy-yu commented 7 months ago

@mikeroberts3000 First, thank you!

I ran into a similar problem, but going from Unity to PyTorch3D.

For anyone who runs into the same problem, I want to share my solution here.

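import numpy as np

def transform_unity_projection_matrix(projection_matrix):
    # Work on a float copy so the caller's Unity matrix is left untouched
    # (the [i, j] indexing below also requires a NumPy array rather than a nested list).
    projection_matrix = np.array(projection_matrix, dtype=float)

    # Use the same focal scale for x and y, matching PyTorch3D's non-square NDC convention.
    max_value = max(projection_matrix[0, 0], projection_matrix[1, 1])
    projection_matrix[0, 0] = max_value
    projection_matrix[1, 1] = max_value

    # Negate entries (1, 2), (2, 2), and (3, 2) of the third column.
    projection_matrix[1, 2] *= -1
    projection_matrix[2, 2] *= -1
    projection_matrix[3, 2] *= -1

    return projection_matrix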

Note: I didn't use the aspect ratio to update elements [0][0] and [1][1] here, because the recalculated results were not exactly identical. Instead, I assign the max value to both, because they are supposed to be identical.
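The max-of-diagonal step can be cross-checked against the explicit aspect-ratio scaling used earlier in the thread. For a symmetric, OpenGL-style perspective matrix with square pixels, both approaches give the same x/y entries up to floating-point rounding. The sketch below uses a synthetic gluPerspective-style matrix as a stand-in for Unity's; Unity's actual matrix is not reproduced here, and all helper names are hypothetical:

```python
import numpy as np


def symmetric_projection(fov_y, aspect, near=0.3, far=1000.0):
    # OpenGL-style perspective matrix (synthetic stand-in for a Unity projection matrix).
    cot = 1.0 / np.tan(fov_y / 2.0)
    P = np.zeros((4, 4))
    P[0, 0] = cot / aspect
    P[1, 1] = cot
    P[2, 2] = -(far + near) / (far - near)
    P[2, 3] = -(2.0 * far * near) / (far - near)
    P[3, 2] = -1.0
    return P


def max_diagonal(P):
    # The max-of-diagonal step from the function above.
    m = max(P[0, 0], P[1, 1])
    return m, m


def aspect_scaled_diagonal(P, width_pixels, height_pixels):
    # The explicit W/H (or H/W) scaling used in mikeroberts3000's M_pytorch_from_opengl.
    sx = width_pixels / height_pixels if width_pixels > height_pixels else 1.0
    sy = height_pixels / width_pixels if height_pixels > width_pixels else 1.0
    return P[0, 0] * sx, P[1, 1] * sy


for w, h in [(512, 384), (384, 512), (640, 480)]:
    P = symmetric_projection(np.deg2rad(60.0), w / h)
    print((w, h), max_diagonal(P), aspect_scaled_diagonal(P, w, h))
```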