dendenxu / diff-gaussian-rasterization

Improved 3DGS rasterizer.

Question about culling #1

Closed: shijialew closed this issue 5 months ago

shijialew commented 6 months ago

In my understanding, the full projection matrix generated by the 3DGS repo converts z to [0, 1] instead of the [-1, 1] range defined in NDC space. Here you implement culling as:

return (p_proj.z > -1 - padding) && (p_proj.z < 1 + padding) &&
       (p_proj.x > -1 - xy_padding) && (p_proj.x < 1. + xy_padding) &&
       (p_proj.y > -1 - xy_padding) && (p_proj.y < 1. + xy_padding);

I guess you do the culling in NDC space, but that may not correspond to the projection used by 3DGS. Would it be more reasonable to change p_proj.z > -1 - padding to p_proj.z > -padding? Maybe I have some misunderstanding of your implementation.
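To make this concrete, here is a minimal numeric check of the claim, using the z-row coefficients of the official 3DGS getProjectionMatrix (utils/graphics_utils.py) and, for comparison, the classic OpenGL projection; the near/far values are arbitrary:

# The z row of the official 3DGS getProjectionMatrix maps camera-space depth
# z in [n, f] to NDC [0, 1]; the classic OpenGL matrix maps it to [-1, 1]
n, f = 0.01, 100.0

def ndc_z_3dgs(z):
    # 3DGS: P[2, 2] = f / (f - n), P[2, 3] = -f * n / (f - n), P[3, 2] = 1
    return (f / (f - n) * z - f * n / (f - n)) / z

def ndc_z_opengl(z):
    # OpenGL looks down -z, so z_cam = -z and w_clip = z for depth z > 0:
    # P[2, 2] = -(f + n) / (f - n), P[2, 3] = -2 * f * n / (f - n), P[3, 2] = -1
    return (-(f + n) / (f - n) * -z - 2 * f * n / (f - n)) / z

print(ndc_z_3dgs(n), ndc_z_3dgs(f))      # -> 0.0 1.0
print(ndc_z_opengl(n), ndc_z_opengl(f))  # -> -1.0 1.0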

dendenxu commented 6 months ago

Ah, your analysis looks solid! I didn't fully check the original 3DGS projection matrix and only assumed it would be a "flipped OpenGL". It seems that, aside from the flipping, its NDC definition also differs a bit from traditional OpenGL. In my usage, I convert the camera parameters from OpenCV to "flipped OpenGL" as follows:

# Requires torch and numpy (as np); dotdict, focal2fov, getWorld2View,
# getProjectionMatrix and affine_inverse are helpers from the surrounding codebase
def convert_to_gaussian_camera(K: torch.Tensor,
                               R: torch.Tensor,
                               T: torch.Tensor,
                               H: torch.Tensor,
                               W: torch.Tensor,
                               n: torch.Tensor,
                               f: torch.Tensor,
                               cpu_K: torch.Tensor,
                               cpu_R: torch.Tensor,
                               cpu_T: torch.Tensor,
                               cpu_H: torch.Tensor,
                               cpu_W: torch.Tensor,
                               cpu_n: float = 0.01,
                               cpu_f: float = 100.,
                               ):
    output = dotdict()

    output.image_height = cpu_H
    output.image_width = cpu_W

    FoVx = focal2fov(cpu_K[0, 0].cpu(), cpu_W.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?
    FoVy = focal2fov(cpu_K[1, 1].cpu(), cpu_H.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?

    # Use .float() to avoid AMP issues
    output.world_view_transform = getWorld2View(R, T).transpose(0, 1).float()  # this is now to be right multiplied
    output.projection_matrix = getProjectionMatrix(K, H, W, n, f).transpose(0, 1).float()  # this is now to be right multiplied
    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix).float()   # 4, 4
    output.camera_center = (-R.mT @ T)[..., 0].float()  # B, 3, 1 -> 3,

    # Set up rasterization configuration
    output.tanfovx = np.tan(FoVx * 0.5)
    output.tanfovy = np.tan(FoVy * 0.5)

    return output
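For reference, a hypothetical usage sketch of the function above; the intrinsics and image size are made-up values, and K, R, T follow the OpenCV world-to-camera convention assumed by this conversion:

import torch

# Made-up example camera; in practice these come from your dataset
K = torch.tensor([[1111.0, 0.0, 512.0],
                  [0.0, 1111.0, 512.0],
                  [0.0, 0.0, 1.0]])
R = torch.eye(3)       # world-to-camera rotation (OpenCV convention)
T = torch.zeros(3, 1)  # world-to-camera translation
H = W = torch.tensor(1024)
n, f = torch.tensor(0.01), torch.tensor(100.0)

# The cpu_* arguments mirror the GPU tensors so that scalar reads do not
# force device synchronization; here we simply reuse the same values
camera = convert_to_gaussian_camera(K, R, T, H, W, n, f,
                                    K, R, T, H, W)
# camera.full_proj_transform, camera.tanfovx, etc. then feed the
# rasterizer's settings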

def convert_to_cpu_gaussian_camera(K: torch.Tensor,
                                   R: torch.Tensor,
                                   T: torch.Tensor,
                                   H: torch.Tensor,
                                   W: torch.Tensor,
                                   n: torch.Tensor,
                                   f: torch.Tensor,
                                   cpu_K: torch.Tensor,
                                   cpu_R: torch.Tensor,
                                   cpu_T: torch.Tensor,
                                   cpu_H: torch.Tensor,
                                   cpu_W: torch.Tensor,
                                   cpu_n: float = 0.01,
                                   cpu_f: float = 100.,
                                   ):
    output = dotdict()

    output.image_height = cpu_H
    output.image_width = cpu_W

    FoVx = focal2fov(cpu_K[0, 0].cpu(), cpu_W.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?
    FoVy = focal2fov(cpu_K[1, 1].cpu(), cpu_H.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?

    # Use .float() to avoid AMP issues
    output.world_view_transform = getWorld2View(cpu_R, cpu_T).transpose(0, 1).float()  # this is now to be right multiplied
    output.projection_matrix = getProjectionMatrix(cpu_K, cpu_H, cpu_W, cpu_n, cpu_f).transpose(0, 1).float()  # this is now to be right multiplied
    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix).float()   # 4, 4
    output.camera_center = (-cpu_R.mT @ cpu_T)[..., 0].float()  # B, 3, 1 -> 3,

    # Set up rasterization configuration
    output.tanfovx = np.tan(FoVx * 0.5)
    output.tanfovy = np.tan(FoVy * 0.5)

    return output

def convert_to_gl_camera(K: torch.Tensor,
                         R: torch.Tensor,
                         T: torch.Tensor,
                         H: torch.Tensor,
                         W: torch.Tensor,
                         n: torch.Tensor,
                         f: torch.Tensor,
                         cpu_K: torch.Tensor,
                         cpu_R: torch.Tensor,
                         cpu_T: torch.Tensor,
                         cpu_H: torch.Tensor,
                         cpu_W: torch.Tensor,
                         cpu_n: float = 0.01,
                         cpu_f: float = 100.,
                         ):
    output = dotdict()

    output.image_height = cpu_H
    output.image_width = cpu_W

    output.K = K
    output.R = R
    output.T = T

    output.znear = cpu_n
    output.zfar = cpu_f

    output.FoVx = focal2fov(cpu_K[0, 0].cpu(), cpu_W.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?
    output.FoVy = focal2fov(cpu_K[1, 1].cpu(), cpu_H.cpu())  # MARK: MIGHT SYNC IN DIST TRAINING, WHY?

    # Use .float() to avoid AMP issues
    output.world_view_transform = getWorld2View(R, T).transpose(0, 1).float()  # this is now to be right multiplied
    # Flip the y and z axes of the camera-to-world matrix to convert from the
    # OpenCV convention (x right, y down, z forward) to the OpenGL convention
    # (x right, y up, z backward)
    c2w = affine_inverse(output.world_view_transform.mT).mT
    c2w[1] *= -1
    c2w[2] *= -1
    output.world_view_transform = affine_inverse(c2w.mT).mT
    output.projection_matrix = getProjectionMatrix(K, H, W, n, f).transpose(0, 1).float()  # this is now to be right multiplied
    # The matrix is stored transposed, so [2][j] indexes P[j][2]: negate these
    # z coefficients to match an OpenGL camera looking down -z
    output.projection_matrix[2][0] *= -1
    output.projection_matrix[2][2] *= -1
    output.projection_matrix[2][3] *= -1
    output.full_proj_transform = torch.matmul(output.world_view_transform, output.projection_matrix).float()   # 4, 4
    output.camera_center = (-R.mT @ T)[..., 0].float()  # B, 3, 1 -> 3,

    # Set up rasterization configuration
    output.tanfovx = np.tan(output.FoVx * 0.5)
    output.tanfovy = np.tan(output.FoVy * 0.5)

    return output

I will add this as a reference in the code and clarify it a bit in the documentation. This does somewhat break the near-far culling, though... Maybe we can add a camera_type interface and set the default to 3dgs, as sketched below? A PR is also welcome.
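As a rough sketch of what that interface could look like, here is a Python mirror of the CUDA in_frustum check; the camera_type parameter and the padding defaults are hypothetical, not part of the current implementation:

def in_frustum(x: float, y: float, z: float,
               padding: float = 0.01, xy_padding: float = 0.5,
               camera_type: str = '3dgs') -> bool:
    # 3DGS maps depth to NDC [0, 1]; OpenGL maps it to [-1, 1]
    z_min = -padding if camera_type == '3dgs' else -1.0 - padding
    return (z_min < z < 1.0 + padding and
            -1.0 - xy_padding < x < 1.0 + xy_padding and
            -1.0 - xy_padding < y < 1.0 + xy_padding)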