facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

The rendering result is abnormal, after updating the Textures Class. #334

Closed liuzhihui2046 closed 4 years ago

liuzhihui2046 commented 4 years ago

After I updated the pytorch3d version (from 0.2.0), in which the Textures class has been deprecated, the rendered face turned black. Below is my code; all other parameters remain unchanged except the textures.

    self.raster_settings = RasterizationSettings(
        image_size=256,
        blur_radius=0.0,
        bin_size=0,
        faces_per_pixel=1,
    )

    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=self.raster_settings
        ),
        shader=SoftPhongShader(device=device, cameras=cameras)
    )

    # textures = Textures(verts_rgb=verts_rgb[:, :, [0, 1, 2]].to(device))  # deprecated API
    textures = TexturesVertex(verts_features=verts_rgb[:, :, [0, 1, 2]]).to(device)
    meshes = Meshes(
        verts=final_verts,
        faces=faces_expand,
        textures=textures
    )
    image = phong_renderer(meshes_world=meshes)

Has anyone encountered a similar problem?

WMCh commented 4 years ago

The Textures class has been moved to the renderer module, so you can do from pytorch3d.renderer import Textures.

liuzhihui2046 commented 4 years ago

The Textures class has been moved to the renderer module, so you can do from pytorch3d.renderer import Textures.

I imported Textures and the code runs, but the result is abnormal. When I changed to HardPhongShader, it worked.

gkioxari commented 4 years ago

Hi @liuzhihui2046, how exactly is the result unexpected? You don't provide a mesh that would allow us to reproduce the results, nor do you provide images of the unexpected output. The best way we can help is if you provide code and data so that we can reproduce your issue.

liuzhihui2046 commented 4 years ago

@gkioxari Sorry for not responding in time. I use 3DMM meshes to render a face. If I use SoftPhongShader, the rendered face is all black; it is normal if I change to HardPhongShader. I think this bug has nothing to do with the mesh and can be reproduced with other meshes.

(image: badcase2)

(image: badcase1)

The complete code is as follows

    self.blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

    self.R, self.T = look_at_view_transform(dist=1000, elev=0, azim=0, device="cuda")
    self.raster_settings = RasterizationSettings(
        image_size=256,
        blur_radius=0.0,
        bin_size=0,
        faces_per_pixel=1,
    )
    cameras = PerspectiveCameras(device=device, focal_length=focal_lenghts,
                                 principal_point=trans_t, R=self.R, T=self.T)

    pointlights = PointLights(device=device, location=light_dirc)

    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=self.raster_settings
        ),
        shader=SoftPhongShader(device=device,
                               blend_params=self.blend_params,
                               cameras=cameras)
    )

    faces_expand = self.faces.expand(project_lmk.size(0), self.faces.size(0), self.faces.size(1)).to(device)

    final_verts = project_lmk  # mesh verts
    verts_rgb = torch.ones_like(final_verts)  # mesh colors

    # textures = Textures(verts_rgb=verts_rgb[:, :, [0, 1, 2]].to(device))  # deprecated API
    textures = TexturesVertex(verts_features=verts_rgb[:, :, [0, 1, 2]]).to(device)
    meshes = Meshes(
        verts=final_verts,
        faces=faces_expand,
        textures=textures
    )

    image = phong_renderer(meshes_world=meshes)

FengQiaojun commented 4 years ago

I also hit this problem. It happens when using SoftPhongShader with faces_per_pixel=1 in the raster settings. The mesh being rendered is a single layer, just like the face model above. (I'm not sure whether other values of faces_per_pixel work.) I further checked the commits and I believe this started after commit 5852b74; before that commit, SoftPhongShader worked just fine (in commit 8e9ff15). Probably something in blending.py is broken. I can share the .obj mesh I used for testing.

Left: the rendered mesh. Middle: the texture map as a reference. Right: (ignore this).

Successful case: (image: Screenshot from 2020-09-02 15-41-30)
Fail case: (image: Screenshot from 2020-09-02 15-41-47)

sbranson commented 4 years ago

Thanks for the detailed examples. I think this behavior is happening because the scale of the meshes in these examples is larger than usual (vertex coordinates on the order of -1000 to 1000), which causes an overflow in blending.py. Can you try scaling the mesh down? Meshes are often scaled so that vertices are on the order of -1 to 1; in any case they have to be smaller than the default value of zfar=100 to avoid this problem.
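A framework-agnostic sketch of the rescaling idea (plain Python lists here for illustration; with PyTorch3D you would apply the same centering and scaling to the verts tensor before building the Meshes object):

```python
def normalize_verts(verts, target_radius=1.0):
    """Center a vertex list at the origin and scale it so the
    largest coordinate magnitude equals target_radius -- a sketch
    of the 'scale the mesh to roughly [-1, 1]' advice above."""
    n = len(verts)
    center = [sum(v[i] for v in verts) / n for i in range(3)]
    centered = [[v[i] - center[i] for i in range(3)] for v in verts]
    radius = max(abs(c) for v in centered for c in v)
    scale = target_radius / radius
    return [[c * scale for c in v] for v in centered]

# A face-scale mesh with coordinates on the order of -1000 to 1000:
verts = [[-1000.0, 0.0, 500.0], [1000.0, 200.0, -500.0], [0.0, -800.0, 0.0]]
small = normalize_verts(verts)
print(max(abs(c) for v in small for c in v))  # 1.0 after rescaling
```

After normalizing, the camera dist should be reduced proportionally so the mesh still fills the frame.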

liuzhihui2046 commented 4 years ago

Thanks for the detailed examples. I think this behavior is happening because the scale of the meshes in these examples is larger than usual (vertex coordinates on the order of -1000 to 1000), which causes an overflow in blending.py. Can you try scaling the mesh down? Meshes are often scaled so that vertices are on the order of -1 to 1; in any case they have to be smaller than the default value of zfar=100 to avoid this problem.

It works! Thank you very much. The dist value I set may have been too large: I originally set it to 1000, with a vertex scale of about -100 to 100. I then reduced dist to 100, the corresponding vertex scale shrank to -10 to 10, and the result is correct. I think it was caused by the reasons you mentioned above.

liuzhihui2046 commented 4 years ago

@sbranson Although adjusting the scale of the vertices solves this problem, I think this is still a bug. If we have to normalize the vertices to the right scale whenever we load an obj in order to get a correct result, that is not very reasonable. What do you think?

FengQiaojun commented 4 years ago

Thanks for the suggestion! It worked. However, as @liuzhihui2046 mentioned, the tool should be able to adapt to different scales. So how can we change the zfar parameter for the blending?

sbranson commented 4 years ago

Thanks, I agree there's a bug and we are hoping to release a fix soon (probably within the next week).

The issue has to do with what happens when triangles fall beyond the far clipping plane (zfar). Renderers like OpenGL clip triangles beyond zfar so that they aren't visible. PyTorch3D is a bit inconsistent about what it does, and the triangles aren't always clipped. It is a bug that triangles beyond zfar appear black (they should be clipped). It should also be possible to increase cameras.zfar without rescaling the mesh, but there is a bug in that the zfar parameter isn't propagated to the blending function.
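The mechanism can be illustrated with a deliberately simplified, plain-Python sketch of the depth weighting used in softmax blending (this is a hypothetical toy, not the actual softmax_rgb_blend code; only the znear/zfar names follow the discussion above). A fragment's normalized inverse depth competes against a background term, so once z exceeds zfar the fragment's weight collapses to zero and the (black) background wins:

```python
import math

def blend_weight(z, znear=1.0, zfar=100.0, gamma=1e-4, eps=1e-10):
    """Toy depth weighting: fragments inside [znear, zfar] get a
    positive normalized inverse depth and dominate the softmax;
    fragments beyond zfar get a negative one and lose to the
    background term, which renders as black here."""
    z_inv = (zfar - z) / (zfar - znear)       # negative when z > zfar
    z_inv_max = max(z_inv, eps)               # for numerical stability
    w_fragment = math.exp((z_inv - z_inv_max) / gamma)
    w_background = math.exp((eps - z_inv_max) / gamma)
    return w_fragment / (w_fragment + w_background)

print(blend_weight(z=50.0))    # inside the frustum: weight ~ 1, fragment color shows
print(blend_weight(z=1000.0))  # beyond zfar=100: weight ~ 0, background (black) shows
```

Scaling the mesh down moves z back inside [znear, zfar], which is why the rescaling workaround fixes the black output.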

We're hoping to have a fix in very soon, but as a temporary workaround you can manually increase the default value of zfar in softmax_rgb_blend() in blending.py. It might be safer to scale down the mesh instead. If you're using PyTorch3D for differentiable rendering, there are various parameters and loss functions, and I'm not totally sure all of them are invariant to scale (i.e., some default parameter choices may be better suited to smaller meshes).

Anyway, thanks to you both for helping to find this issue.

nikhilaravi commented 4 years ago

This issue should have been fixed by https://github.com/facebookresearch/pytorch3d/commit/f8ea5906c0ae5ef6fb7800e3f0a05ebf56cdd927. Please reopen if you have further questions!