nerfstudio-project / nerfacc

A General NeRF Acceleration Toolbox in PyTorch.
https://www.nerfacc.com/

Wrong behaviour of traverse_grid function #201

Closed: arterms closed this issue 1 year ago.

arterms commented 1 year ago

Great job! I'm using the traverse_grids function to sample points along a ray in the range between near_planes and far_planes. I have noticed that the function samples points from near_planes up to the end of the grid cell that contains far_planes. For a degenerate grid of size (1, 1, 1, 1) this produces a large number of unnecessary samples. I found that the problem might be fixed by clamping the estimated tdist to the given tmax in the setup_traversal function in _nerfacc.cuda.csrc.include.utilsgrid.cuh:

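// Proposed fix: cap each per-axis distance at tmax; axes the ray runs parallel to fall back to tmax.
tdist = make_float3(
    (ray.dir.x == 0.0f) ? tmax : min(tmax_xyz.x, tmax),
    (ray.dir.y == 0.0f) ? tmax : min(tmax_xyz.y, tmax),
    (ray.dir.z == 0.0f) ? tmax : min(tmax_xyz.z, tmax)
);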
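
To make the effect of this clamp concrete, here is a minimal Python sketch of a per-axis 3D-DDA setup. It is not nerfacc's setup_traversal: the function name next_crossings, a grid whose cells start at coordinate 0 with a uniform cell_size, and returning tmax for zero direction components are assumptions for illustration. Each entry is the ray distance at which the ray crosses the next cell boundary on that axis; a standard DDA advances to the smallest of these.

```python
import math


def next_crossings(origin, direction, t_min, t_max, cell_size=1.0, clamp=True):
    """Ray distance to the next cell boundary on each axis (hypothetical sketch).

    Assumes a uniform grid whose cells start at coordinate 0. With clamp=True
    each value is capped at t_max, mirroring the proposed fix above.
    """
    tdist = []
    for o, d in zip(origin, direction):
        if d == 0.0:
            # The ray runs parallel to this axis and never crosses a boundary on it.
            tdist.append(t_max)
            continue
        p = o + t_min * d  # coordinate on this axis where marching starts
        idx = math.floor(p / cell_size)  # index of the cell containing p
        boundary = (idx + 1) * cell_size if d > 0 else idx * cell_size
        t_cross = (boundary - o) / d
        tdist.append(min(t_cross, t_max) if clamp else t_cross)
    return tdist
```

For the tilted ray used in the test later in this thread (origin (-1, 0, 0), direction proportional to (1, 0.01, 0.01), tmin = 1.2, tmax = 1.5), the unclamped crossings are roughly [2.0002, 100.01, 100.01], so even the smallest one lies past tmax; with the clamp all three become 1.5. For an axis-parallel ray, the zero components already return tmax in this sketch, so the smallest crossing is bounded by tmax even without the clamp.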

Is there any other way to fix the problem?

liruilong940607 commented 1 year ago

Hi, it wouldn't result in unnecessary sampled points, because the ray terminates at the bbox of the grid. (Actually, it starts from max(near_plane, grid intersection) and terminates at min(far_plane, grid intersection).) See here in the code.

A code snippet for testing the case you describe:

>>> import torch
>>> import nerfacc
>>> rays_o = torch.tensor([[-1., 0., 0.]], device="cuda:0")
>>> rays_d = torch.tensor([[1., 0., 0.]], device="cuda:0")
>>> binaries = torch.ones((1, 1, 1, 1), dtype=torch.bool, device="cuda:0")
>>> aabbs = torch.tensor([[0., 0., 0., 1., 1., 1.]], device="cuda:0")
>>> intervals, samples = nerfacc.traverse_grids(rays_o, rays_d, binaries, aabbs, step_size=0.2)
>>> intervals.vals
tensor([1.0000, 1.2000, 1.4000, 1.6000, 1.8000, 2.0000], device='cuda:0')
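
The interval described above, from max(near_plane, grid entry) to min(far_plane, grid exit), can be sketched in plain Python with the standard slab method for ray/AABB intersection. The helper names clip_ray_to_aabb and uniform_edges are hypothetical, and placing edges at even spacing is a simplification, not nerfacc's actual sampler.

```python
import math


def clip_ray_to_aabb(origin, direction, aabb, near=0.0, far=math.inf):
    """Slab-method ray/AABB test, then clip the hit interval to [near, far].

    Returns (t_start, t_end) with t_start = max(near, box entry) and
    t_end = min(far, box exit), or None if the interval is empty.
    """
    lo, hi = aabb[:3], aabb[3:]
    t_enter, t_exit = -math.inf, math.inf
    for o, d, l, h in zip(origin, direction, lo, hi):
        if d == 0.0:
            if o < l or o > h:
                return None  # parallel to this slab and outside it
            continue
        t0, t1 = (l - o) / d, (h - o) / d
        t_enter = max(t_enter, min(t0, t1))
        t_exit = min(t_exit, max(t0, t1))
    t_start, t_end = max(near, t_enter), min(far, t_exit)
    return (t_start, t_end) if t_start <= t_end else None


def uniform_edges(t_start, t_end, step):
    """Edges t_start, t_start + step, ... not exceeding t_end (small rounding slack)."""
    n = math.floor((t_end - t_start) / step + 1e-9)
    return [t_start + i * step for i in range(n + 1)]
```

For the axis-aligned ray in the snippet above, clip_ray_to_aabb returns (1.0, 2.0), and uniform_edges(1.0, 2.0, 0.2) gives six edges approximately equal to the printed intervals.vals.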

arterms commented 1 year ago

Hi! Please find below the test which fails when the ray direction is not parallel to the axes:

import pytest
import torch

device = "cuda:0"


# is_available must be called: the bare function object is always truthy,
# so without the parentheses the test would never be skipped.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_traverse_grids():
    from nerfacc.grid import traverse_grids

    # A ray starting at x = -1, slightly tilted away from the x-axis.
    rays_o = torch.tensor([[-1., 0., 0.]], device=device)
    rays_d = torch.tensor([[1., 0.01, 0.01]], device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    # Degenerate grid: a single occupied cell covering the unit cube.
    binaries = torch.ones((1, 1, 1, 1), dtype=torch.bool, device=device)
    aabbs = torch.tensor([[0., 0., 0., 1., 1., 1.]], device=device)

    near_planes = torch.tensor([1.2], device=device)
    far_planes = torch.tensor([1.5], device=device)
    eps = 1e-7

    intervals, samples = traverse_grids(
        rays_o, rays_d, binaries, aabbs, step_size=0.1,
        near_planes=near_planes, far_planes=far_planes,
    )
    # Every interval edge should lie within [near_planes, far_planes].
    assert (intervals.vals >= (near_planes - eps)).all()
    assert (intervals.vals <= (far_planes + eps)).all()


if __name__ == "__main__":
    test_traverse_grids()

I looked through the code. If I understand correctly, the variable this_tmax is used in the setup_traversal function to estimate tdist, and tdist defines t_traverse, which is used for early stopping while marching along the ray. But from the implementation of setup_traversal I can see that this_tmax, passed into the function as the argument tmax, is used only to estimate the index of the final grid cell where ray marching stops, not the maximum distance along the ray. So points are sampled up to the border of the grid cell that contains the far_planes value.
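
The mechanism described here can be reproduced in a hypothetical one-dimensional model of the marching loop. This is not nerfacc's kernel: march_to_final_cell and the cell_exits list (the ray distance at which the ray leaves each visited cell) are assumptions for illustration. The far plane only selects the final cell, and sampling then continues to that cell's exit.

```python
def march_to_final_cell(t_start, far, step, cell_exits):
    """Emit evenly spaced edges cell by cell until the exit of the final cell.

    Hypothetical model of the reported behaviour: `cell_exits` lists, in order,
    the ray distance at which the ray leaves each occupied cell it visits. `far`
    only selects the final cell (the first one whose exit reaches `far`); it is
    never compared against the sample distances themselves.
    """
    final = next(
        (i for i, t_exit in enumerate(cell_exits) if t_exit >= far),
        len(cell_exits) - 1,
    )
    edges, k = [], 0
    for t_exit in cell_exits[: final + 1]:
        # Keep stepping while the next edge is still inside the current cell.
        while t_start + k * step <= t_exit:
            edges.append(t_start + k * step)
            k += 1
    return edges
```

With the degenerate single-cell grid and the tilted ray from the test, the ray leaves the cell at about t = 2.0002, so march_to_final_cell(1.2, 1.5, 0.1, [2.0002]) returns nine edges from 1.2 to about 2.0, well beyond far = 1.5.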

liruilong940607 commented 1 year ago

Oh yeah, I see what you mean, and you are right. The marching crosses the far plane in order to finish the last grid cell.

Actually, t_traverse is the current marching distance at each step, so a simple fix is to break the marching once t_traverse exceeds the far plane.
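
A minimal sketch of that fix in a hypothetical one-dimensional model of the marching loop (march_with_far_break and cell_exits are illustrative names, not nerfacc's API): marching stops as soon as the current distance passes the far plane, instead of running on to the exit of the last cell.

```python
def march_with_far_break(t_start, far, step, cell_exits, tol=1e-6):
    """Emit evenly spaced edges cell by cell, stopping once t passes `far`.

    Hypothetical model: `cell_exits` lists, in order, the ray distance at which
    the ray leaves each occupied cell it visits; `tol` absorbs rounding error.
    """
    edges, k = [], 0
    for t_exit in cell_exits:
        while True:
            t = t_start + k * step
            if t > far + tol:
                return edges  # the fix: stop marching past the far plane
            if t > t_exit:
                break  # this cell is finished; continue with the next one
            edges.append(t)
            k += 1
    return edges
```

With the single-cell example (t_start = 1.2, far = 1.5, step = 0.1, cell exit at about 2.0002) this yields four edges, from 1.2 to 1.5, so every edge stays inside [near, far].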

Feel free to open a PR if you want, or I can get it fixed over the weekend.