DecaYale / RNNPose

RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization, CVPR 2022
Apache License 2.0

how to speed up render #7

Closed mate-huaboy closed 2 years ago

mate-huaboy commented 2 years ago

Hey, I use your code to render for pose estimation, but I find it is too slow to train. My code is as follows:

```python
from pytorch3d.io import IO
import torch
import torch.nn as nn
import sys
import os.path as osp
cur_dir = osp.abspath(osp.dirname(__file__))
sys.path.insert(0, cur_dir)
from ColorShader import ColorShader
from pytorch3d.renderer import (
    PerspectiveCameras,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    BlendParams,
)

import datetime


class DiffRender(nn.Module):
    def __init__(self, mesh_path, render_texture=False, render_image_size=(64, 64)):
        super().__init__()

        # self.mesh = mesh
        if mesh_path.endswith('.ply'):
            self.mesh = IO().load_mesh(mesh_path)

        self.cam_opencv2pytch3d = torch.tensor(
            [[-1, 0, 0, 0],
             [0, -1, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1]], dtype=torch.float32
        )
        self.cameras = PerspectiveCameras(image_size=[render_image_size], in_ndc=False)  # why not use R and t

        self.raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,  # 0
            perspective_correct=True
        )
        self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
        rasterizer = MeshRasterizer(
            cameras=self.cameras,
            raster_settings=self.raster_settings
        )
        self.renderer = MeshRenderer(
            rasterizer,
            shader=ColorShader(blend_params=self.blend_params)
        )

    def to(self, *args, **kwargs):
        if 'device' in kwargs.keys():
            device = kwargs['device']
        else:
            device = args[0]
        super().to(device)

        self.mesh = self.mesh.to(device)
        self.cam_opencv2pytch3d = self.cam_opencv2pytch3d.to(device)
        self.renderer = self.renderer.to(device)

        return self

    def forward(self, T, K, render_image_size, near=0.1, far=6, render_texture=None, mode='bilinear'):
        """
        Args:
            T: (B,3,4) or (B,4,4)
            K: (B,3,3)
            render_image_size (tuple): (h,w)
            near (float, optional): Defaults to 0.1.
            far (int, optional): Defaults to 6.
            But actually here B==1
        """
        start = datetime.datetime.now()
        B = T.shape[0]
        # face_attribute = vert_attribute[self.faces.long()]

        device = T.device
        T = T[..., :3, :3]  # add

        # T = self.cam_opencv2pytch3d[:3,:3].to(device=T.device) @ T
        T = self.cam_opencv2pytch3d[:3, :3] @ T

        ## X_cam = X_world R + t
        R = T[..., :3, :3].transpose(-1, -2)
        # t = T[..., :3, 3]
        t = torch.tensor([0, 0, 1], device=T.device).reshape(1, 3)
        # t = -(R@T[...,:3,3:]).squeeze(-1)
        # start1 = datetime.datetime.now()

        cameras = PerspectiveCameras(R=R, T=t,
                                     focal_length=torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=-1),
                                     principal_point=K[:, :2, 2],
                                     image_size=[render_image_size] * B, in_ndc=False,
                                     device=device)

        s = datetime.datetime.now()
        target_images = self.renderer(self.mesh, cameras=cameras, blendParams=self.blend_params)  # 1*480*640*4
        e = datetime.datetime.now()
        # print((start2-start1).microseconds)
        print((s - start).microseconds)
        print((e - s).microseconds)
        return target_images[..., :3]


class DiffRenderer_Normal_Wrapper(nn.Module):
    def __init__(self, obj_paths, device="cuda", render_texture=False):
        super().__init__()

        self.renderers = []
        for obj_path in obj_paths:
            self.renderers.append(
                DiffRender(obj_path, render_texture).to(device=device)
            )

        self.renderers = nn.ModuleList(self.renderers)

    def forward(self, model_idx, R, K, render_image_size, near=0.1, far=6, render_tex=False):
        """
        model_idx: (24,)
        R: (24,3,3)
        K: (24,3,3)
        """
        color_outputs = []
        # model_idx has 24 elements between 0 and 12, i.e. the model ids, e.g. [1, 2, 3, 2, 5, 8, ...]
        for b, _ in enumerate(model_idx):
            # model_idx = self.cls2idx[model_names[b]]
            color = self.renderers[model_idx[b]](R[b:b + 1], K[b:b + 1], render_image_size,
                                                 near, far, render_texture=render_tex)

            color_outputs.append(color)
        return color_outputs
```

In my program I call DiffRenderer_Normal_Wrapper externally: each step samples 24 instances from the 13 object classes (with repetition), estimates their corresponding rotations, and then renders the 24 images with DiffRenderer_Normal_Wrapper. This rendering step is very slow (about one second) and has become the bottleneck of training. I also noticed that when two adjacent samples belong to the same model, training runs faster, so it seems the principle of locality could be used to speed things up, but I am not sure how to do it. Can anyone give me some advice on speeding up the rendering? Thank you very much!
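One way to exploit the locality mentioned above is to group the 24 samples by model id and issue one batched render per group instead of 24 single-image calls. Below is a minimal sketch of that grouping, assuming DiffRender.forward is adapted to accept a batch B > 1 (for example by extending the mesh with self.mesh.extend(B) and broadcasting t to shape (B, 3)); render_grouped and its arguments are hypothetical names, not part of this repository:

```python
import torch

def render_grouped(renderers, model_idx, R, K, render_image_size):
    """Render all samples that share a model in one batched call, then restore the input order."""
    model_idx = torch.as_tensor(model_idx)
    outputs = [None] * len(model_idx)
    for mid in model_idx.unique():
        sel = (model_idx == mid).nonzero(as_tuple=True)[0]  # positions that use this model
        # One batched render per model instead of one render per sample.
        colors = renderers[int(mid)](R[sel], K[sel], render_image_size)
        for pos, img in zip(sel.tolist(), colors):
            outputs[pos] = img
    return torch.stack(outputs, dim=0)
```

With 13 models and 24 samples per step, this reduces the number of renderer invocations from 24 to at most 13 per step, usually fewer.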

DecaYale commented 2 years ago

We implement the function you mentioned with pytorch3d. Rendering 24 images at a time could be time-consuming with the current pytorch3d implementation. But our training only needs to render one image per GPU and thus our training is more efficient. Your issue seems to be less relevant to our project. I suggest you raise this issue in the pytorch3d's forum. I may close this issue temporarily

mate-huaboy commented 2 years ago

> We implement the function you mentioned with pytorch3d. Rendering 24 images at a time could be time-consuming with the current pytorch3d implementation. But our training only needs to render one image per GPU and thus our training is more efficient. Your issue seems to be less relevant to our project. I suggest you raise this issue in the pytorch3d's forum. I may close this issue temporarily

Thank you for your reply, I have already solved this problem.