facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

How to render meshes with holes correctly with differentiability? #275

Closed AndresCasado closed 4 years ago

AndresCasado commented 4 years ago

Hi. After learning how to use the differentiable renderer correctly (see my previous issue #236) I've been able to get some results for my project, but now I've hit another wall: meshes with holes seem to be problematic.

Example (code and .obj below): rendering a sphere with some holes returns differently shaded results. The top row is rendered with the differentiable renderer, while the bottom row is rendered with a normal renderer.

(Image: Imgur)

It looks to me like there is a problem with the z-depth of the fragments, but I can't pinpoint where. I've searched the issues but haven't seen an exact solution. I think it could be related to #116 and #149.

I've tried some things:

These things are what make me think this has something to do with the z-depth of the fragments, although it's strange that the 6th column is correct.

Before noticing this problem with holes I already had some suspicions about my settings, because the shading of my differentiable images was strange (the first and last columns would differ even without holes). So this may be a symptom of something else.

Reproduce

I've uploaded the .obj file to https://pastebin.com/G0s47z0V

import numpy as np
import pytorch3d.io as torch3d_io
import pytorch3d.renderer as torch3d_render
import pytorch3d.structures as torch3d_struct
import torch

device = torch.device('cuda')

mesh_filename = 'hole_ico.obj'
icomesh = torch3d_io.load_objs_as_meshes([mesh_filename], device=device)

number = 7
icomesh = icomesh.extend(number)
R, T = torch3d_render.look_at_view_transform(
    dist=5.0,
    elev=30.0,
    azim=torch.linspace(0.0, 360.0 - 360.0 / number, number),
    device=device,
)

cameras = torch3d_render.OpenGLPerspectiveCameras(
    R=R, T=T,
    device=device,
)

lights = torch3d_render.PointLights(
    location=((-5, 0, 5),),
    device=device,
)

color = torch.ones_like(icomesh.verts_padded(), device=device)
icomesh.textures = torch3d_struct.Textures(
    verts_rgb=color,
)

# "Normal" (hard) rasterization: one face per pixel, no blur
normal_raster_settings = torch3d_render.RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=None,
    max_faces_per_bin=None,
    # cull_backfaces=True,
)

normal_blend_params = torch3d_render.BlendParams(
    sigma=1e-9, gamma=1e-9,
)

normal_renderer = torch3d_render.MeshRenderer(
    rasterizer=torch3d_render.MeshRasterizer(
        cameras=cameras,
        raster_settings=normal_raster_settings,
    ),
    shader=torch3d_render.SoftPhongShader(
        device=device,
        cameras=cameras,
        lights=lights,
        blend_params=normal_blend_params,
    )
)

diff_blend_params = torch3d_render.BlendParams(
    sigma=1e-6, gamma=0.00001,
    background_color=(1.0, 1.0, 1.0),
)

# Differentiable (soft) rasterization: blur radius chosen so that
# sigmoid(-blur_radius / sigma) is roughly 1e-4 at the edge of the blur region
diff_raster_settings = torch3d_render.RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * diff_blend_params.sigma,
    faces_per_pixel=10,
    bin_size=None,
    max_faces_per_bin=None,
    # cull_backfaces=True,
)

diff_renderer = torch3d_render.MeshRenderer(
    rasterizer=torch3d_render.MeshRasterizer(
        cameras=cameras,
        raster_settings=diff_raster_settings,
    ),
    shader=torch3d_render.SoftPhongShader(
        cameras=cameras,
        lights=lights,
        blend_params=diff_blend_params,
        device=device,
    )
)

normal_img = normal_renderer(icomesh)
diff_img = diff_renderer(icomesh)

import torchvision as tv

for image, path in [(normal_img, "normal.png"), (diff_img, "diff.png")]:
    # renderer output is (N, H, W, C); save_image expects (N, C, H, W)
    tv.utils.save_image(image.transpose(1, 3).transpose(2, 3), path)
AndresCasado commented 4 years ago

I added thickness to the mesh so all the visible faces have the same orientation, but it didn't solve the problem, so I don't think it's related to face orientation.

(Image: Imgur)

Using the thickened mesh with backface culling enabled (settings shown below) still shows differences between the renderings.

(Image: Imgur)

(Top row is differentiable rendering in both images)
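
For reference, the backface-culling test just enables the flag that is commented out in the repro script, e.g. (a minimal sketch):

culled_raster_settings = torch3d_render.RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
    cull_backfaces=True,  # skip faces whose normals point away from the camera
)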

AndresCasado commented 4 years ago

After checking the zbuffer of the first column I learned that the nearest fragments had positive values, while the farthest fragments had negative values.

(Image: Imgur)

Then I read the blending function, and the z-buffer-related calculations seemed strange to me: if I need more weight on the front faces, why does the function give more weight to the back ones?

So I did an experiment: what if I reverse the function so it gives more weight to the front fragments? This is my code:

z_inv = (fragments.zbuf - zfar) / (znear - zfar) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

Here znear and zfar are initialized to the max and min values of the zbuffer.
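
For context, those three lines replace the z_inv computation inside the softmax blending; roughly, they sit in a function like the following (a paraphrase of the SoftRasterizer-style blend, so prob_map, eps, colors and background_color are stand-ins rather than the exact library variables):

mask = fragments.pix_to_face >= 0                                    # valid fragments only
prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
eps = 1e-10

# my reversed depth weighting (the three lines above)
z_inv = (fragments.zbuf - zfar) / (znear - zfar) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

# the background takes the residual weight; normalize and blend the per-face colors
delta = torch.exp((eps - z_inv_max) / blend_params.gamma)
denom = weights_num.sum(dim=-1)[..., None] + delta
blended = (weights_num[..., None] * colors).sum(dim=-2)
blended = (blended + delta * background_color) / denom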

This is the result:

(Image: Imgur)

Now the first and last columns are the ones that have the correct result!

This reinforces my assumption that the problem is z-buffer related.

I'm going to check whether the z-buffers of all the viewpoints follow the same convention (high in front, low in back), because I think a possible cause is that the renderer is using world or object z-depth instead of camera z-depth.
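
A quick way to do that check with the objects from my repro script (a sketch that just prints the valid depth range per view from the rasterizer's fragments):

fragments = diff_renderer.rasterizer(icomesh)
zbuf = fragments.zbuf                       # (N, H, W, faces_per_pixel)
valid = fragments.pix_to_face >= 0          # background entries have pix_to_face == -1
for view in range(zbuf.shape[0]):
    z = zbuf[view][valid[view]]
    print(view, z.min().item(), z.max().item())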

AndresCasado commented 4 years ago

After checking the zbuffers I confirmed the problem: they were in world coordinates instead of view coordinates. This issue was supposedly fixed (see #210) in commit f0ba4c5, but that commit does not belong to any branch, and the file currently on the master branch does not look like the code modified in that commit (https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/renderer.py).

Strangely enough, my local install has a file similar to the renderer.py modified in f0ba4c5. After creating my own version of the MeshRenderer with that fix and reverting the changes I made to the blending function, the differentiable image looks just like it should.
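
For reference, my custom renderer is essentially the forward pass from that commit (a sketch; the helpers _clip_barycentric_coordinates and _interpolate_zbuf are the ones used in f0ba4c5 and their names/location may differ in the current code):

from pytorch3d.renderer.mesh.rasterizer import Fragments

class FixedMeshRenderer(torch.nn.Module):
    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs):
        fragments = self.rasterizer(meshes_world, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.rasterizer.raster_settings)
        if raster_settings.blur_radius > 0.0:
            # clip barycentric coords and re-interpolate the zbuf so depths are in view space
            clipped_bary = _clip_barycentric_coordinates(fragments.bary_coords)
            clipped_zbuf = _interpolate_zbuf(fragments.pix_to_face, clipped_bary, meshes_world)
            fragments = Fragments(
                pix_to_face=fragments.pix_to_face,
                zbuf=clipped_zbuf,
                bary_coords=clipped_bary,
                dists=fragments.dists,
            )
        return self.shader(fragments, meshes_world, **kwargs)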

What steps should be taken now? Should I rename this issue? Close it and reopen #210?

nikhilaravi commented 4 years ago

@AndresCasado thanks for pointing this out. The barycentric clipping and zbuf interpolation have now been moved to the rasterization CUDA kernel. There is a setting called clip_barycentric_coordinates in the RasterizationSettings. Have you set this to True? The default is False.

Previously, in https://github.com/facebookresearch/pytorch3d/commit/f0ba4c53e58f20ab13db5f4f3a283b5f03b8d8ed, we had set barycentric clipping to happen automatically whenever the blur_radius was greater than 0, but now it has to be set manually in the RasterizationSettings.
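
With current master the soft settings would look something like this (a sketch; depending on your version the keyword may be clip_barycentric_coords or clip_barycentric_coordinates, so check RasterizationSettings in your install):

diff_raster_settings = torch3d_render.RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * diff_blend_params.sigma,
    faces_per_pixel=10,
    clip_barycentric_coords=True,  # keyword name may vary across versions
)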

Let me know if this resolves your issue!

AndresCasado commented 4 years ago

Thanks for the explanation. Now I understand why the code is different: I had an old version installed. I thought I had updated it to make sure this wasn't an already-solved issue, but it seems I messed up somewhere, sorry.

I've just reinstalled PyTorch3D from master and it works just fine, so issue solved. Thanks again, @nikhilaravi.