GradientSpaces / LoopSplat

[3DV 2025] LoopSplat: Loop Closure by Registering 3D Gaussian Splats
https://loopsplat.github.io/
MIT License
264 stars 12 forks source link

I ran your code and saw the fast rasterization initialization, but I couldn't find where the initialization parameters are passed in. Could you tell me where these packages are defined? #13

Closed enterfutures closed 1 week ago

enterfutures commented 2 weeks ago

from typing import NamedTuple

import torch
import torch.nn as nn

from . import _C

def cpu_deep_copy_tuple(input_tuple):
    """Return a copy of *input_tuple* with every tensor cloned into CPU memory.

    Each ``torch.Tensor`` item is moved with ``.cpu()`` and then ``.clone()``d,
    so the result never aliases the originals (``.cpu()`` on a tensor that is
    already on the CPU returns the same object; the clone makes it a copy).
    Non-tensor items are passed through unchanged.

    Args:
        input_tuple: Any iterable of arguments (tensors and plain values).

    Returns:
        A new ``tuple`` with the same length and item order.
    """
    return tuple(
        item.cpu().clone() if isinstance(item, torch.Tensor) else item
        for item in input_tuple
    )

def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    theta,
    rho,
    raster_settings,
):
    """Functional entry point for the differentiable Gaussian rasterizer.

    Forwards all arguments positionally, in the same order, to
    ``_RasterizeGaussians.apply`` and returns its outputs unchanged.
    """
    autograd_inputs = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        theta,
        rho,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*autograd_inputs)

def _run_rasterizer_op(op, args, debug, dump_path, error_message):
    """Invoke a native rasterizer entry point, snapshotting its inputs in debug mode.

    Args:
        op: Callable from the compiled ``_C`` module.
        args: Positional arguments for ``op``, already in the order it expects.
        debug: When true, ``args`` are deep-copied to the CPU before the call
            and, if the call raises, that copy is written to ``dump_path`` with
            ``torch.save`` and ``error_message`` is printed before re-raising.
        dump_path: File name the debug snapshot is saved under.
        error_message: Text printed when the call fails in debug mode.

    Returns:
        Whatever ``op`` returns.
    """
    if not debug:
        return op(*args)
    # Copy the inputs before the native call can corrupt them.
    cpu_args = cpu_deep_copy_tuple(args)
    try:
        return op(*args)
    except Exception:
        torch.save(cpu_args, dump_path)
        print(error_message)
        raise


class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge between PyTorch tensors and the compiled ``_C`` rasterizer.

    ``forward`` calls ``_C.rasterize_gaussians`` and ``backward`` calls
    ``_C.rasterize_gaussians_backward``; the Python side only marshals
    arguments into the positional order the C++ functions expect.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        theta,
        rho,
        raster_settings,
    ):
        """Rasterize the Gaussians.

        ``means2D``, ``theta`` and ``rho`` are not passed to the C++ forward;
        they are inputs here so that ``backward`` can route gradients to them.

        Returns:
            ``(color, radii, depth, opacity, n_touched)`` as produced by
            ``_C.rasterize_gaussians``.
        """
        # Restructure arguments the way that the C++ lib expects them
        # (order matters: it is positional).
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.projmatrix_raw,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        # Invoke C++/CUDA rasterizer
        (
            num_rendered,
            color,
            radii,
            geomBuffer,
            binningBuffer,
            imgBuffer,
            depth,
            opacity,
            n_touched,
        ) = _run_rasterizer_op(
            _C.rasterize_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw.dump",
            "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.",
        )

        # Keep relevant tensors (including the rasterizer's work buffers) for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        )
        return color, radii, depth, opacity, n_touched

    @staticmethod
    def backward(ctx, grad_out_color, grad_out_radii, grad_out_depth, grad_out_opacity, grad_n_touched):
        """Compute gradients for every ``forward`` input.

        Only ``grad_out_color`` and ``grad_out_depth`` are passed to the C++
        backward; incoming gradients for ``radii``, ``opacity`` and
        ``n_touched`` are ignored.

        Returns:
            One gradient per ``forward`` input, in the same order; ``None`` for
            ``raster_settings``.
        """
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
        ) = ctx.saved_tensors

        # Restructure args as C++ method expects them (positional order matters)
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.projmatrix_raw,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            grad_out_depth,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            raster_settings.debug,
        )

        # Compute gradients for relevant tensors by invoking backward method
        (
            grad_means2D,
            grad_colors_precomp,
            grad_opacities,
            grad_means3D,
            grad_cov3Ds_precomp,
            grad_sh,
            grad_scales,
            grad_rotations,
            grad_tau,
        ) = _run_rasterizer_op(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw.dump",
            "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n",
        )

        # grad_tau is reduced over rows of 6 values; the first three go to rho,
        # the last three to theta.
        # NOTE(review): presumably these are per-Gaussian camera-pose gradients
        # (translation/rotation) — confirm against the CUDA backward.
        grad_tau = torch.sum(grad_tau.view(-1, 6), dim=0)
        grad_rho = grad_tau[:3].view(1, -1)
        grad_theta = grad_tau[3:].view(1, -1)

        # Must match the order of forward()'s inputs (after ctx).
        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            grad_theta,
            grad_rho,
            None,
        )

        return grads

class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-view settings passed to the Gaussian rasterizer.

    Every field is forwarded to the compiled ``_C`` rasterizer. Field order is
    part of the public interface (positional construction), so do not reorder.
    """

    image_height: int
    image_width: int
    tanfovx: float
    tanfovy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    projmatrix: torch.Tensor
    projmatrix_raw: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    prefiltered: bool
    # When true, native-call inputs are snapshotted and dumped on failure.
    debug: bool

class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the differentiable Gaussian rasterizer.

    Holds the per-view ``raster_settings`` and, in ``forward``, validates the
    optional inputs, substitutes empty tensors for the ones not supplied and
    calls ``rasterize_gaussians``.
    """

    def __init__(self, raster_settings):
        # Must be ``__init__`` (not ``init``): otherwise nn.Module is never
        # initialised and ``raster_settings`` reaches nn.Module.__init__.
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return the result of ``_C.mark_visible`` for ``positions``.

        Mark visible points (based on frustum culling for camera) with a
        boolean. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None, theta=None, rho=None):
        """Rasterize the Gaussians with the stored settings.

        Exactly one of ``shs`` / ``colors_precomp`` must be given. Either both
        ``scales`` and ``rotations`` are given and ``cov3D_precomp`` is not,
        or ``cov3D_precomp`` is given and neither ``scales`` nor ``rotations``
        is. Any optional argument left as ``None`` is replaced by an empty
        ``torch.Tensor([])`` placeholder before the native call.

        Returns:
            Whatever ``rasterize_gaussians`` returns.

        Raises:
            Exception: If the argument combination above is violated.
        """
        raster_settings = self.raster_settings

        if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # The C++ side takes every argument positionally, so absent optionals
        # are sent as empty tensors rather than None.
        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])

        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])
        if theta is None:
            theta = torch.Tensor([])
        if rho is None:
            rho = torch.Tensor([])

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            theta,
            rho,
            raster_settings
        )

# Question from the original issue: "where [are] this forward and backward?"
# In this module they only marshal arguments; the actual work is done by the
# compiled `_C` extension (imported via `from . import _C`) through
# `_C.rasterize_gaussians` and `_C.rasterize_gaussians_backward`, whose source
# is not part of this file.