nerfstudio-project / nerfacc

A General NeRF Acceleration Toolbox in PyTorch.
https://www.nerfacc.com/

Tip for eikonal loss calculation #215

Closed lzhnb closed 1 year ago

lzhnb commented 1 year ago

Thanks for your amazing project!

eikonal_loss is an important term for SDF fields (e.g. NeuS / VolSDF). In the early versions (like v3.5), nerfacc handled the process in a fine-grained manner (ray-marching sampling, field forward, and volume rendering were done manually). In this version (v5.2), nerfacc redesigned the APIs and packed these steps into a pipeline consisting of an estimator and rendering. The estimator design is really fantastic, but the rendering API means that I can only treat my field (or a forward function) as a sigma & rgb generator. As a result, when I want to get the sample points to calculate eikonal_loss, I need to record the intervals, recover the sample points, and feed them into the network for another forward pass.

Could a future version provide more convenient and flexible APIs? This is just a friendly tip. :) Below is an example where I modified render_image_with_occgrid from utils.py in your examples to support the eikonal_loss calculation:

def render_image_with_occgrid_ek(
    # scene
    radiance_field: torch.nn.Module,
    estimator: OccGridEstimator,
    rays: Rays,
    # rendering options
    near_plane: float = 0.0,
    far_plane: float = 1e10,
    render_step_size: float = 1e-3,
    render_bkgd: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # test options
    test_chunk_size: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Render the pixels of an image."""
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
    else:
        num_rays, _ = rays_shape

    def sigma_fn(
        t_starts: torch.Tensor, t_ends: torch.Tensor, ray_indices: torch.Tensor
    ) -> torch.Tensor:
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = radiance_field.query_density(positions)
        return sigmas.squeeze(-1)

    def rgb_sigma_fn(
        t_starts: torch.Tensor, t_ends: torch.Tensor, ray_indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        results = radiance_field(positions, t_dirs)
        rgbs, sigmas = results["rgbs"], results["sigmas"]
        return rgbs, sigmas.squeeze(-1)

    results = []
    ek_points = []
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else test_chunk_size
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        ray_indices, t_starts, t_ends = estimator.sampling(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            sigma_fn=sigma_fn,  # for visibility sampling
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )
        rgb, opacity, depth, extras = rendering(
            t_starts,
            t_ends,
            ray_indices,
            n_rays=chunk_rays.origins.shape[0],
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )
        chunk_results = [rgb, opacity, depth, len(t_starts)]
        results.append(chunk_results)

        # NOTE: get the points for calculating eikonal loss
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        ek_points.append(positions)

    colors, opacities, depths, n_rendering_samples = [
        torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r for r in zip(*results)
    ]
    return (
        colors.view((*rays_shape[:-1], -1)),
        opacities.view((*rays_shape[:-1], -1)),
        depths.view((*rays_shape[:-1], -1)),
        torch.cat(ek_points),
        sum(n_rendering_samples),
    )
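
The returned ek_points can then be fed through the SDF network a second time to get the gradient norm. A minimal sketch of that step, assuming the field exposes a query_sdf method (the name is hypothetical; use whatever returns the raw SDF values):

# Re-forward the sampled points to compute the eikonal term.
# `query_sdf` is an assumed method name, not part of nerfacc.
positions = ek_points.detach().requires_grad_(True)
sdf = radiance_field.query_sdf(positions)
(grad,) = torch.autograd.grad(
    outputs=sdf,
    inputs=positions,
    grad_outputs=torch.ones_like(sdf),
    create_graph=True,  # keep the graph so eikonal_loss can be backpropagated
)
eikonal_loss = ((grad.norm(dim=-1) - 1.0) ** 2).mean()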
lzhnb commented 1 year ago

Another way is to conduct the volume rendering manually via render_weight_from_density and accumulate_along_rays. :)
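
A minimal sketch of what that could look like inside the chunk loop, reusing the variables from render_image_with_occgrid_ek above (render_weight_from_density and accumulate_along_rays are nerfacc's lower-level rendering APIs; the forward pass is kept in scope so the positions can be reused for eikonal_loss without re-forwarding):

from nerfacc import accumulate_along_rays, render_weight_from_density

# sample intervals with the estimator, exactly as in the function above
ray_indices, t_starts, t_ends = estimator.sampling(
    chunk_rays.origins,
    chunk_rays.viewdirs,
    sigma_fn=sigma_fn,
    near_plane=near_plane,
    far_plane=far_plane,
    render_step_size=render_step_size,
    stratified=radiance_field.training,
    cone_angle=cone_angle,
    alpha_thre=alpha_thre,
)
n_rays = chunk_rays.origins.shape[0]

# single forward pass; `positions` stays available for the eikonal term
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
results = radiance_field(positions, t_dirs)
rgbs, sigmas = results["rgbs"], results["sigmas"].squeeze(-1)

# manual volume rendering instead of the rendering() wrapper
weights, trans, alphas = render_weight_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
)
colors = accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
opacities = accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)
depths = accumulate_along_rays(
    weights,
    values=(t_starts + t_ends)[:, None] / 2.0,
    ray_indices=ray_indices,
    n_rays=n_rays,
)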