dendenxu / fast-gaussian-rasterization

A geometry-shader-based, globally CUDA-sorted, high-performance 3D Gaussian Splatting rasterizer. It can achieve a 5-10x rendering speedup over the vanilla diff-gaussian-rasterization.

use fast-gaussian-rasterization in gaussian-splatting #4

Open liyipeng137 opened 1 month ago

liyipeng137 commented 1 month ago

Great work! Can I use it directly in gaussian-splatting? I tried but failed 😿. Maybe I need to modify the training part of the code.

dendenxu commented 1 month ago

Hi, sorry for the delayed reply. Yes, you should be able to use it with any pre-trained 3DGS model. However, there's a known issue: our culling function does not match the original's, as mentioned here. I plan to match the culling functions soon. Alternatively, you can simply replace the OpenCV-to-OpenGL camera parameter conversion function with ours, which maps depth to [-1, 1]:

def convert_to_gaussian_camera(K: torch.Tensor,
                               R: torch.Tensor,
                               T: torch.Tensor,
                               H: torch.Tensor,
                               W: torch.Tensor,
                               n: torch.Tensor,
                               f: torch.Tensor,
                               cpu_K: torch.Tensor,
                               cpu_R: torch.Tensor,
                               cpu_T: torch.Tensor,
                               cpu_H: int,
                               cpu_W: int,
                               cpu_n: float = 0.01,
                               cpu_f: float = 100.,
                               ):
    output = dotdict()

    output.image_height = cpu_H
    output.image_width = cpu_W

    FoVx = focal2fov(cpu_K[0, 0].cpu(), cpu_W.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?
    FoVy = focal2fov(cpu_K[1, 1].cpu(), cpu_H.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?

    # Use .float() to avoid AMP issues
    output.world_view_transform = getWorld2View(R, T).transpose(0, 1).float()  # this is now to be right multiplied
    output.projection_matrix = getProjectionMatrix(K, H, W, n, f).transpose(0, 1).float()  # this is now to be right multiplied
    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix).float()   # 4, 4
    output.camera_center = (-R.mT @ T)[..., 0].float()  # B, 3, 1 -> 3,

    # Set up rasterization configuration
    output.tanfovx = np.tan(FoVx * 0.5)
    output.tanfovy = np.tan(FoVy * 0.5)

    return output

def convert_to_cpu_gaussian_camera(K: torch.Tensor,
                                   R: torch.Tensor,
                                   T: torch.Tensor,
                                   H: torch.Tensor,
                                   W: torch.Tensor,
                                   n: torch.Tensor,
                                   f: torch.Tensor,
                                   cpu_K: torch.Tensor,
                                   cpu_R: torch.Tensor,
                                   cpu_T: torch.Tensor,
                                   cpu_H: int,
                                   cpu_W: int,
                                   cpu_n: float = 0.01,
                                   cpu_f: float = 100.,
                                   ):
    output = dotdict()

    output.image_height = cpu_H
    output.image_width = cpu_W

    FoVx = focal2fov(cpu_K[0, 0].cpu(), cpu_W.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?
    FoVy = focal2fov(cpu_K[1, 1].cpu(), cpu_H.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?

    # Use .float() to avoid AMP issues
    output.world_view_transform = getWorld2View(cpu_R, cpu_T).transpose(0, 1).float()  # this is now to be right multiplied
    output.projection_matrix = getProjectionMatrix(cpu_K, cpu_H, cpu_W, cpu_n, cpu_f).transpose(0, 1).float()  # this is now to be right multiplied
    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix).float()   # 4, 4
    output.camera_center = (-cpu_R.mT @ cpu_T)[..., 0].float()  # B, 3, 1 -> 3,

    # Set up rasterization configuration
    output.tanfovx = np.tan(FoVx * 0.5)
    output.tanfovy = np.tan(FoVy * 0.5)

    return output
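
The depth-range difference mentioned above can be illustrated with a small sketch. It is not code from either repository: the function names are hypothetical, and it assumes the other conversion targets a [0, 1] depth range after the perspective divide, while the functions above target [-1, 1]. The near and far values reuse the `cpu_n` and `cpu_f` defaults (0.01 and 100).

```python
import math


def ndc_depth_zero_to_one(z: float, near: float, far: float) -> float:
    """Depth after the perspective divide when [near, far] maps to [0, 1].

    Hypothetical helper: assumes a camera looking down +z (OpenCV style),
    so the clip-space w component equals the view-space depth z.
    """
    a = far / (far - near)
    b = -(far * near) / (far - near)
    return (a * z + b) / z


def ndc_depth_minus_one_to_one(z: float, near: float, far: float) -> float:
    """Depth after the perspective divide when [near, far] maps to [-1, 1].

    Hypothetical helper under the same +z camera assumption.
    """
    a = (far + near) / (far - near)
    b = -2.0 * far * near / (far - near)
    return (a * z + b) / z


if __name__ == "__main__":
    near, far = 0.01, 100.0  # same values as the cpu_n / cpu_f defaults
    for z in (near, 1.0, far):
        z01 = ndc_depth_zero_to_one(z, near, far)
        z11 = ndc_depth_minus_one_to_one(z, near, far)
        # The two conventions are related by z11 = 2 * z01 - 1.
        print(f"z={z:>7}: [0, 1] -> {z01:.6f}, [-1, 1] -> {z11:.6f}")
```

Because the two values differ for the same view-space depth, a projection matrix built for one range and consumed by a depth test written for the other would place the near and far limits at different distances, which is one plausible way a culling mismatch could arise.
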

liyipeng137 commented 1 month ago

Thanks for the reply! I will try this.