NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering
Other
1.28k stars 139 forks source link

Range Mode Rasterization and Interpolation only works for the first minibatch index #164

Closed raoshashank closed 5 months ago

raoshashank commented 5 months ago

This is based on the use case in https://github.com/NVlabs/diff-dope/issues/3. I'm trying to render RGBD images of YCB objects with different meshes in each minibatch index using the range mode. However, I'm finding that only the first minibatch index is correctly rendered. As instructed in the documentation, I am providing the start index and the number of triangles per batch index as a tensor of Bx2 shape and providing the triangle indices and the point cloud as Nx4 and Mx3 tensors respectively. I suspect that the triangle indices are not being used correctly in the range mode. It would be super helpful if a range mode example is provided in the repo itself.. MWE (uses YCB object meshes):

import os
import os.path as osp
import matplotlib.pyplot as plt
from icecream import ic
import numpy as np
import torch
import pybullet as p
import trimesh
import nvdiffrast.torch as dr
#most of the code is from: https://github.com/NVlabs/diff-dope
def display_image_grid(images):
    B = images.shape[0]

    # Calculate the number of rows and columns for the grid
    rows = int(B**0.5)
    cols = (B + rows - 1) // rows

    # Create a grid of subplots
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10))

    for i in range(rows):
        for j in range(cols):
            index = i * cols + j

            if index < B:
                # Get the RGB image at the current position in the grid
                rgb_image = images[index].numpy()

                # Display the RGB image
                axes[i, j].imshow(rgb_image)
                axes[i, j].axis('off')
            else:
                # If there are fewer images than the grid size, hide the empty subplots
                axes[i, j].axis('off')

    plt.show()
def xfm_points(points,mtx):
    return torch.matmul(
                torch.nn.functional.pad(points.contiguous(), pad=(0, 1), mode="constant", value=1.0),
                torch.transpose(mtx, 1, 2),
            )

def render_texture_batch_range_mode(
    glctx,
    mtx,
    pos,
    pos_clip_ja,
    pos_idx,
    resolution,
    ranges,
    uv=None,
    uv_idx=None,
    tex=None,
    vtx_color=None,
    return_rast_out=False,
):
    if not type(resolution) == list:
        resolution = [resolution, resolution]
    posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)
    tri = pos_idx.int().contiguous()
    cts_ranges = ranges.int().contiguous()
    assert len(pos_clip_ja.shape) == 2 and len(posw.shape) == 2 and ranges!=None        
    rast_out, rast_out_db = dr.rasterize(
        glctx, pos_clip_ja.contiguous(), tri, resolution=resolution, ranges = cts_ranges
    )

    # compute the depth
    gb_pos, _ = dr.interpolate(
        posw.contiguous(),
        rast_out,
        tri,
        rast_db=rast_out_db,
        diff_attrs= "all",
    )

    shape_keep = gb_pos.shape
    gb_pos = gb_pos.reshape(shape_keep[0], -1, shape_keep[-1])
    gb_pos = gb_pos[..., :3]
    depth = xfm_points(gb_pos.contiguous(),mtx)
    depth = depth.reshape(shape_keep)[..., 2] * -1

    mask, _ = dr.interpolate(torch.ones(tri.shape).cuda(), 
        rast_out, tri,rast_db=rast_out_db,diff_attrs="all") #
    mask = dr.antialias(mask, rast_out, pos_clip_ja.contiguous(), tri)

    # compute vertex color interpolation
    if vtx_color is None:
        texc, texd = dr.interpolate(
            uv, rast_out, uv_idx, rast_db=rast_out_db, diff_attrs="all"
        )
        color = dr.texture(
            tex,
            texc,
            texd,
            filter_mode="linear",
        )

        color = color * torch.clamp(rast_out[..., -1:], 0, 1)  # Mask out background.
    else:
        color, _ = dr.interpolate(vtx_color, rast_out, tri)
        color = color * torch.clamp(rast_out[..., -1:], 0, 1)  # Mask out background.
    if not return_rast_out:
        rast_out = None
    return {"rgb": color, "depth": depth, "rast_out": rast_out, 'mask':mask}

class Mesh(torch.nn.Module):

    def __init__(self, path_model, scale):
        super().__init__()

        # load the mesh
        self.path_model = path_model
        self.to_process = [
            "pos",
            "pos_idx",
            "vtx_color",
            "tex",
            "uv",
            "uv_idx",
            "vtx_normals",
        ]

        mesh = trimesh.load(self.path_model, force="mesh")

        pos = np.asarray(mesh.vertices)
        pos_idx = np.asarray(mesh.faces)

        normals = np.asarray(mesh.vertex_normals)

        pos_idx = torch.from_numpy(pos_idx.astype(np.int32))

        vtx_pos = torch.from_numpy(pos.astype(np.float32)) * scale
        vtx_normals = torch.from_numpy(normals.astype(np.float32))
        bounding_volume = [
            [
                torch.min(vtx_pos[:, 0]),
                torch.min(vtx_pos[:, 1]),
                torch.min(vtx_pos[:, 2]),
            ],
            [
                torch.max(vtx_pos[:, 0]),
                torch.max(vtx_pos[:, 1]),
                torch.max(vtx_pos[:, 2]),
            ],
        ]

        dimensions = [
            bounding_volume[1][0] - bounding_volume[0][0],
            bounding_volume[1][1] - bounding_volume[0][1],
            bounding_volume[1][2] - bounding_volume[0][2],
        ]
        center_point = [
            ((bounding_volume[0][0] + bounding_volume[1][0]) / 2).item(),
            ((bounding_volume[0][1] + bounding_volume[1][1]) / 2).item(),
            ((bounding_volume[0][2] + bounding_volume[1][2]) / 2).item(),
        ]

        if isinstance(mesh.visual, trimesh.visual.texture.TextureVisuals):
            tex = np.array(mesh.visual.material.image) / 255.0
            uv = mesh.visual.uv
            uv[:, 1] = 1 - uv[:, 1]
            uv_idx = np.asarray(mesh.faces)

            tex = torch.from_numpy(tex.astype(np.float32))
            uv_idx = torch.from_numpy(uv_idx.astype(np.int32))
            vtx_uv = torch.from_numpy(uv.astype(np.float32))

            self.pos_idx = pos_idx
            self.pos = vtx_pos
            self.tex = tex
            self.uv = vtx_uv
            self.uv_idx = uv_idx
            self.bounding_volume = bounding_volume
            self.dimensions = dimensions
            self.center_point = center_point
            self.vtx_normals = vtx_normals
            self.has_textured_map = True

        else:
            vertex_color = mesh.visual.vertex_colors[..., :3] / 255.0
            vertex_color = torch.from_numpy(vertex_color.astype(np.float32))

            self.pos_idx = pos_idx
            self.pos = vtx_pos
            self.vtx_color = vertex_color
            self.bounding_volume = bounding_volume
            self.dimensions = dimensions
            self.center_point = center_point
            self.vtx_normals = vtx_normals
            self.has_textured_map = False

        self._batchsize_set = False

    def __str__(self):
        return f"mesh @{self.path_model}. vtx:{self.pos.shape} on {self.pos.device}"

    def __repr__(self):
        return f"mesh @{self.path_model}. vtx:{self.pos.shape} on {self.pos.device}"

    def set_batchsize(self, batchsize):
        """
        Set the batchsize of the mesh object to match the optimization.

        Args:
            batchsize (int): batchsize for the arrays used by nv diff rast

        """

        for key, value in vars(self).items():
            if not key in self.to_process:
                continue
            if self._batchsize_set is False:
                vars(self)[key] = torch.stack([vars(self)[key]] * batchsize, dim=0)
            else:
                vars(self)[key] = torch.stack([vars(self)[key][0]] * batchsize, dim=0)

        for key, value in self._parameters.items():
            if not key in self.to_process:
                continue
            if self._batchsize_set is False:
                self._parameters[key] = torch.stack(
                    [self._parameters[key]] * batchsize, dim=0
                )
            else:
                self._parameters[key] = torch.stack(
                    [self._parameters[key][0]] * batchsize, dim=0
                )

        if self._batchsize_set is False:
            self._batchsize_set = True

    def cuda(self):
        """
        put the arrays from `to_process` on gpu
        """
        super().cuda()

        for key, value in vars(self).items():
            if not key in self.to_process:
                continue
            vars(self)[key] = vars(self)[key].cuda()

    def enable_gradients_texture(self):
        """
        Function to enable gradients on the texture *please note* if `set_batchsize` is called after this function the gradients are set to false for the image automatically
        """
        if self.has_textured_map:
            self.tex = torch.nn.Parameter(self.tex, requires_grad=True).to(
                self.tex.device
            )
        else:
            self.vtx_color = torch.nn.Parameter(self.vtx_color, requires_grad=True).to(
                self.vtx_color.device
            )

    def forward(self):
        """
        Pass the information from the mesh back to diff-dope defined in the the `to_process`
        """
        to_return = {}
        for key, value in vars(self).items():
            if not key in self.to_process:
                continue
            to_return[key] = vars(self)[key]
        # if not 'tex' in to_return:
        #     to_return['tex'] = self.tex
        # elif not 'vtx_color' in to_return:
        #     to_return['vtx_color'] = self.vtx_color
        return to_return

def batch_mesh_forward(meshes,object_names):
    props = ['pos','pos_idx','uv','uv_idx']
    mesh_result = {}
    count = {}
    start_index= {}
    #meshes = [self.mesh[o]() for o in object_names]
    meshes = {o: meshes[o]() for o in list(set(object_names))}
    for p in props:    
        count[p] = []
        mesh_result[p] = []
        for n in object_names:
            mesh_result[p].append(meshes[n][p])
            count[p].append(meshes[n][p].shape[0]) 
        mesh_result[p] = torch.concat(mesh_result[p],dim=0).cuda()
        count[p] = torch.tensor(count[p])
        start_index[p] = torch.roll(torch.cumsum(count[p],0),shifts=1,dims=0)
        start_index[p][0]=0

    mesh_result['tex'] = torch.stack([meshes[n]['tex'] for n in object_names])
    mesh_result['ranges'] = torch.vstack((start_index['pos_idx'],count['pos_idx'])).T #shape = num of trianges x 2
    return mesh_result,count,start_index

if __name__ == '__main__':
    glctx = dr.RasterizeGLContext()
    camera_params = {'cam_eye': np.array([-0.59347291,  0.43063584,  0.64322782]),
    'lookat': np.array([0., 0., 0.]),
    'up': np.array([0., 0., 1.]),
    'width': 480,
    'height': 480,
    'nearVal': 0.1,
    'farVal': 100.0,
    'fov': 60.0
    }
    object_names = ['035_power_drill','040_large_marker','021_bleach_cleanser','025_mug']
    textured_trimesh_object_meshes = {}
    nvdr_meshes = {}
    for o in object_names:
        textured_trimesh_object_meshes[o]=trimesh.load(os.path.join('../ycb',o,'textured_simple.obj'))
        nvdr_meshes[o] = Mesh(textured_trimesh_object_meshes[o],1.0)
    #set the camera at origin and object 0.3 units away
    view_matrix = np.asarray(p.computeViewMatrix(camera_params['cam_eye'],camera_params['lookat'],camera_params['up'])).reshape([4,4],order='F')
    projection_matrix = np.asarray(p.computeProjectionMatrixFOV(fov = camera_params['fov'],
                                            aspect = 1.0*camera_params['width']/camera_params['height'],
                                            nearVal = camera_params['nearVal'],farVal=camera_params['farVal']),order='F').reshape([4,4],order='F')
    camera_params['projection_matrix'] = torch.from_numpy(projection_matrix).unsqueeze(0).repeat(4,1,1).float().cuda()
    camera_params['view_matrix'] = torch.from_numpy(view_matrix).unsqueeze(0).repeat(4,1,1).float().cuda()

    obj_pose_T44_xstart = torch.eye(4)
    obj_pose_T44_xstart[2,-1]-=0.3
    obj_pose_T44_xstart = obj_pose_T44_xstart.unsqueeze(0).repeat(4,1,1).float().cuda()
    mesh_result,count,start_index = batch_mesh_forward(nvdr_meshes,object_names)

    #cam-view pose to NDC pose
    final_mtx_proj = torch.matmul(camera_params['projection_matrix'], obj_pose_T44_xstart)
    batchsize = 4
    #point cloud is correct
    pos_clip_ja = torch.zeros(mesh_result['pos'].shape[0],4).cuda()
    for i in range(batchsize):
        s = start_index['pos'][i]
        e = s + count['pos'][i]
        pos_clip_ja[s:e] = xfm_points(mesh_result['pos'][s:e].unsqueeze(0).contiguous(),final_mtx_proj[i].unsqueeze(0))[0]

    renders = render_texture_batch_range_mode(
                glctx=glctx,
                mtx=obj_pose_T44_xstart,
                pos=mesh_result['pos'],
                pos_clip_ja = pos_clip_ja,
                pos_idx=mesh_result['pos_idx'],
                uv=mesh_result['uv'],
                uv_idx=mesh_result['uv_idx'],
                tex = mesh_result['tex'].cuda(),
                ranges=mesh_result['ranges'],
                resolution=[480,480]
            )         
    display_image_grid(renders['rgb'].cpu())

results: image

s-laine commented 5 months ago

Based on the images, it looks like all meshes are referencing the same set of vertices. Have you offset the vertex indices in the meshes' triangles to account for their different starting points in the vertex array?

raoshashank commented 5 months ago

do you mean I should be adding a constant to all the mesh indices for all the batches after the first one based on their position? for example, for the mesh indices for batch 3, I should add N_1_3 = number of triangles in mesh 1 + number of triangles in mesh 2?

s-laine commented 5 months ago

Yes, the triangles should reference the actual indices of the vertices in the position array.

raoshashank commented 5 months ago

Can confirm, this works. For future reference, the batch mesh render function is to be updated as follows:

def batch_mesh_forward(meshes,object_names):
    props = ['pos','pos_idx','uv','uv_idx']
    mesh_result = {}
    count = {}
    start_index= {}
    meshes = {o: meshes[o]() for o in list(set(object_names))}
    for p in props:    
        count[p] = []
        mesh_result[p] = []
        for i,n in enumerate(object_names):
            mesh_result[p].append(meshes[n][p])
            count[p].append(meshes[n][p].shape[0]) 
        mesh_result[p] = torch.concat(mesh_result[p],dim=0).cuda()
        count[p] = torch.tensor(count[p])
        start_index[p] = torch.roll(torch.cumsum(count[p],0),shifts=1,dims=0)
        start_index[p][0]=0
    mesh_result['tex'] = torch.stack([meshes[n]['tex'] for n in object_names])
    mesh_result['ranges'] = torch.vstack((start_index['pos_idx'],count['pos_idx'])).T #shape = num of trianges x 2
    for i in range(len(object_names)):
        s = start_index['pos_idx'][i]
        e = s + count['pos_idx'][i]
        c = start_index['pos'][i]
        mesh_result['pos_idx'][s:e,:]+=c
    return mesh_result,count,start_index

image