NVIDIAGameWorks / kaolin

A PyTorch Library for Accelerating 3D Deep Learning Research
Apache License 2.0
4.34k stars, 538 forks

Issues in DIB-R rasterization #754

Closed Ligo04 closed 1 week ago

Ligo04 commented 11 months ago

[Image: rendered output (not available)]

I am rendering two textured facades of a building OBJ and have run into some problems. For example, the four vertices in the black area in the middle on the right are obviously wrong. The depth values also seem to have no effect: there is no perspective effect (near objects large, far objects small). I'm wondering why this happens. The camera is defined with PinholeIntrinsics using a custom focal length. Here is the code:

import torch
import kaolin

# Assumes `mesh`, `mesh_faces`, `mesh_face_uvs_idx`, `mesh_material_assignments`
# and `materials` were loaded earlier from the building OBJ.
def render(camera: kaolin.render.camera.Camera):
    vertices_cam = camera.extrinsics.transform(mesh.vertices)            # world -> camera
    vertices_img = camera.intrinsics.transform(vertices_cam)[..., 0:2]   # camera -> NDC (keep x, y)
    # per-face vertex positions
    face_vertices_camera = kaolin.ops.mesh.index_vertices_by_faces(vertices_cam, mesh_faces)
    face_vertices_img = kaolin.ops.mesh.index_vertices_by_faces(vertices_img, mesh_faces)
    # camera-space face normals; dibr_rasterization takes only their z component
    face_normals = kaolin.ops.mesh.face_normals(face_vertices_camera, unit=True)
    # per-face UV coordinates, interpolated by the rasterizer
    face_attributes = kaolin.ops.mesh.index_vertices_by_faces(mesh.uvs, mesh_face_uvs_idx[0])
    texture_coords, softmask, face_idx = kaolin.render.mesh.dibr_rasterization(
        camera.height, camera.width,
        face_vertices_camera[:, :, :, -1],   # per-vertex depth (z)
        face_vertices_img,
        face_attributes,
        face_normals[:, :, -1],              # z component of the face normals, shape (B, F)
        rast_backend='cuda')
    # flip along the width axis
    softmask = torch.flip(softmask, dims=(2,))
    texture_coords = torch.flip(texture_coords, dims=(2,))
    face_idx = torch.flip(face_idx, dims=(2,))

    hard_mask = (face_idx > -1)[..., None]   # pixels covered by some face

    # material index per pixel (-1 for background)
    in_material_idx = mesh_material_assignments[0, ...][face_idx]
    in_material_idx[face_idx == -1] = -1

    # result image
    img = torch.zeros((1, camera.height, camera.width, 3), dtype=torch.float, device='cuda')
    for i, material in enumerate(materials):
        mask = in_material_idx == i
        # zero out texture coordinates outside this material's pixels
        _texcoords = torch.where(mask.unsqueeze(3), texture_coords,
                                 torch.tensor(0.0, device='cuda'))
        if _texcoords[_texcoords > 0].shape[0] > 0:
            print(f'material {i}')
            # texture maps as (B, C, H, W); UV origin at the lower-left corner
            pixel_val = kaolin.render.mesh.texture_mapping(
                _texcoords.contiguous(),
                material.permute(0, 3, 1, 2).contiguous(),
                mode='bilinear')
            img[mask] = pixel_val[mask]
    return torch.clamp(img * hard_mask, 0.0, 1.0)

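For reference, under an ideal pinhole model the projected size of an object scales with 1/z, which is the "near large, far small" effect the report expects to see. The following is a minimal pure-Python sketch, not kaolin code. The helpers `project_pinhole` and `projected_width` and the 500-pixel focal length are hypothetical. It also uses the textbook +z-forward convention, which may differ from the sign convention of kaolin's camera space.

```python
def project_pinhole(point, focal, cx=0.0, cy=0.0):
    """Project a camera-space point (x, y, z) with an ideal pinhole model.

    Assumes the camera looks along +z, so visible points have z > 0:
    u = focal * x / z + cx,  v = focal * y / z + cy.
    """
    x, y, z = point
    if z <= 0:
        raise ValueError("point must be in front of the camera (z > 0)")
    return (focal * x / z + cx, focal * y / z + cy)


def projected_width(x0, x1, z, focal):
    """Image-plane width of a horizontal segment [x0, x1] placed at depth z."""
    u0, _ = project_pinhole((x0, 0.0, z), focal)
    u1, _ = project_pinhole((x1, 0.0, z), focal)
    return abs(u1 - u0)


if __name__ == "__main__":
    focal = 500.0  # hypothetical focal length in pixels
    near = projected_width(-1.0, 1.0, z=2.0, focal=focal)
    far = projected_width(-1.0, 1.0, z=8.0, focal=focal)
    print(near, far)  # 500.0 125.0: the copy 4x farther away is 4x narrower
```

If vertex positions pass through a perspective projection like this, distant parts of the facade should appear smaller.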
Caenorst commented 8 months ago

Hi @Ligo04, I can't see your image for some reason. Currently our dibr_rasterization method lacks perspective correction, so I would suggest using nvdiffrast, with which our camera API is fully compatible. See an example here: https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/interactive_visualizer.ipynb
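The sketch below shows what perspective correction means. If a per-vertex attribute such as a UV coordinate is interpolated linearly in screen space, the result is wrong wherever depth varies across a face. The standard fix is to interpolate a/z and 1/z linearly in screen space and then divide. This is the textbook one-dimensional formulation, not code taken from kaolin or nvdiffrast. The helper names are hypothetical, and depths are assumed to be positive view-space distances with a +z-forward pinhole camera.

```python
def affine_interp(a0, a1, t):
    """Screen-space (affine) interpolation: the result without perspective correction."""
    return (1.0 - t) * a0 + t * a1


def perspective_correct_interp(a0, a1, z0, z1, t):
    """Perspective-correct interpolation at screen-space parameter t.

    Interpolates a/z and 1/z linearly on screen, then divides.
    z0, z1 are the (positive) view-space depths of the two endpoints.
    """
    num = (1.0 - t) * a0 / z0 + t * a1 / z1
    den = (1.0 - t) / z0 + t / z1
    return num / den


def ground_truth(p0, p1, a0, a1, t, focal=1.0):
    """Reference value: find the 3D point on segment p0 -> p1 (each given as (x, z))
    that projects to screen-space parameter t, then interpolate linearly in 3D."""
    (x0, z0), (x1, z1) = p0, p1
    u0, u1 = focal * x0 / z0, focal * x1 / z1
    u = (1.0 - t) * u0 + t * u1          # target screen coordinate
    dx, dz = x1 - x0, z1 - z0
    # Solve focal * (x0 + s*dx) / (z0 + s*dz) = u for the 3D parameter s.
    s = (focal * x0 - u * z0) / (u * dz - focal * dx)
    return (1.0 - s) * a0 + s * a1


if __name__ == "__main__":
    # Segment from (x=-1, z=1) to (x=1, z=3); the attribute goes from 0 to 1.
    p0, p1 = (-1.0, 1.0), (1.0, 3.0)
    t = 0.5  # halfway across the projected segment on screen
    print(affine_interp(0.0, 1.0, t))
    print(perspective_correct_interp(0.0, 1.0, p0[1], p1[1], t))
    print(ground_truth(p0, p1, 0.0, 1.0, t))
```

Halfway across the screen, affine interpolation returns 0.5. The perspective-correct and ground-truth values agree at 0.25, up to floating-point rounding. Textures mapped without this correction therefore look warped on faces viewed at an angle.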

Ligo04 commented 8 months ago


OK, thanks! I have been using nvdiffrast for some time now.

shumash commented 1 week ago

FYI, we have also added kaolin.render.easy_render.render_mesh (sample usage here), which has an option to run with either the DIB-R ("cuda") or the "nvdiffrast" backend. Closing this bug; please open a new issue if you run into problems.