nerfstudio-project / nerfacc

A General NeRF Acceleration Toolbox in PyTorch.
https://www.nerfacc.com/

Accelerate Instant-NGP inference #197

Closed Linyou closed 1 year ago

Linyou commented 1 year ago

This PR enhances Nerfacc's Instant-NGP inference performance by implementing the following API changes:

  1. The traverse_grids function has been modified to support both train and test modes.
  2. A new function called mark_invisible_cells has been added to the occ_grid module in order to prevent rendering artifacts in unseen spaces.

liruilong940607 commented 1 year ago

Thanks for implementing this! It is pretty nice to have that support

liruilong940607 commented 1 year ago

At a high level, the two ways of ray marching are pretty similar to each other: the "train" way of marching is to take N rays and march all the steps for each ray. The "test" way of marching is to take all rays and march N steps for each ray, repeating iteratively.
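The two schedules can be compared on a toy example. This is only a sketch in plain Python with uniform steps and no occupancy grid; `march_train` and `march_test` are hypothetical names and not part of nerfacc.

```python
def march_train(nears, fars, step):
    """'Train' schedule: for each ray, march all steps in one go."""
    samples = []
    for near, far in zip(nears, fars):
        n_steps = int(round((far - near) / step))
        samples.append([near + i * step for i in range(n_steps)])
    return samples


def march_test(nears, fars, step, n_steps_per_iter):
    """'Test' schedule: all rays advance at most N steps per iteration."""
    samples = [[] for _ in nears]
    t = list(nears)  # current marching distance of every ray
    active = [i for i in range(len(t)) if t[i] < fars[i]]
    while active:
        for i in active:
            for _ in range(n_steps_per_iter):
                if t[i] >= fars[i]:
                    break
                samples[i].append(t[i])
                t[i] += step
        # Rays that reached their far plane drop out of the next iteration.
        active = [i for i in active if t[i] < fars[i]]
    return samples


nears, fars = [0.0, 1.0], [2.0, 2.5]
train_samples = march_train(nears, fars, step=0.5)
test_samples = march_test(nears, fars, step=0.5, n_steps_per_iter=2)
```

In this toy setting both schedules visit the same sample positions; only the order in which the work is issued differs.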

I feel it should not be that hard to unify the API (as well as the implementation) for the two.

To be more concrete, implementation differences between the two are:

So maybe we can unify them into the same "traverse_grid" function, with extra arguments (max_per_ray_samples=inf, masks=None, t_sorted=None, t_indices=None, hits=None). And for "traverse_grid_test", you can just call that function with an updated "near_planes" at every iteration of marching.

In this case, I think it makes sense to let the CUDA kernel return an extra tensor of shape (n_rays,) that indicates the termination distance during grid traversal, which is essentially the "near_planes" for the next iteration of "traverse_grid_test". (The near_planes you are returning has a confusing name; I think it should be termination_planes or something like that.)
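A minimal sketch of that idea, in plain Python rather than CUDA: one traversal function that caps the samples per ray and returns where each ray stopped, so the "test" loop can feed that value back as the next near plane. `traverse_toy` and `render_test` are hypothetical names, the marching is uniform with no occupancy grid, and this is not the signature that was merged; only the names `max_per_ray_samples`, `near_planes` and `termination_planes` are taken from the comment above.

```python
import math


def traverse_toy(near_planes, far_planes, step, max_per_ray_samples=math.inf):
    """Toy unified traversal: returns (samples, termination_planes)."""
    samples, termination_planes = [], []
    for near, far in zip(near_planes, far_planes):
        ts, t = [], near
        while t < far and len(ts) < max_per_ray_samples:
            ts.append(t)
            t += step
        samples.append(ts)
        termination_planes.append(t)  # where this ray stopped marching
    return samples, termination_planes


def render_test(near_planes, far_planes, step, n):
    """'Test' way: repeat the same call with updated near planes."""
    near = list(near_planes)  # copy, so the caller's list is not modified
    collected = [[] for _ in near]
    while any(t < far for t, far in zip(near, far_planes)):
        chunk, near = traverse_toy(near, far_planes, step, max_per_ray_samples=n)
        for dst, ts in zip(collected, chunk):
            dst.extend(ts)
    return collected


nears, fars = [0.0, 1.0], [2.0, 2.5]
full_samples, full_term = traverse_toy(nears, fars, step=0.5)  # "train" way
iter_samples = render_test(nears, fars, step=0.5, n=2)         # "test" way
```

Returning the termination distances as a new list, instead of overwriting the input near planes, keeps the caller's data untouched between iterations.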

Linyou commented 1 year ago

Sounds good! I think we could unify the API using extra arguments "(max_per_ray_samples=inf, masks=None, t_sorted=None, t_indices=None, hits=None)". Nice idea, BTW!

We also need to unify the return values. I suggest using the data structure defined in "data_spect.h" to store t_start and t_end, instead of creating torch::Tensor directly as I am currently doing. It may be helpful to add new methods for allocating memory in "data_spect.h" specifically for t_start and t_end, since they are pre-allocated in the "test" way. What do you think?

As for near_planes, I already tweaked the code so we don't need to return it, we can just update it inside the "traverse_grid" kernel.

liruilong940607 commented 1 year ago

> We also need to unify the return values. I suggest using the data structure defined in "data_spect.h" to store t_start and t_end, instead of creating torch::Tensor directly as I am currently doing. It may be helpful to add new methods for allocating memory in "data_spect.h" specifically for t_start and t_end, since they are pre-allocated in the "test" way. What do you think?

I think you can use the RaySegmentsSpec, just as the traverse_grid function does. And you can get t_starts and t_ends by:

https://github.com/KAIR-BAIR/nerfacc/blob/8340e19daad4bafe24125150a8c56161838086fa/tests/test_grid.py#L60-L61

> As for near_planes, I already tweaked the code so we don't need to return it, we can just update it inside the "traverse_grid" kernel.

Do you mean that you change its value in place? I would advise against in-place modification, as it is not very user-friendly.

Linyou commented 1 year ago

I have unified the "traverse_grid" API, and now both "train" and "test" can use the same Python function. On the low level, we still need to call separate C functions to launch the CUDA kernel.

Note that the "traverse_grid" function now returns three objects (intervals, samples, termination_planes), and "termination_planes" will be just None when "ray_mask_id" is not provided.

liruilong940607 commented 1 year ago

@Linyou The latest commit should resolve the memory concerns we had before. The test is also updated to match the actual use case. Lmk what you think.

Linyou commented 1 year ago

Thanks! I believe that the current API design is now highly usable for test mode rendering, thanks to the latest commit.

BTW, after this PR is merged, I will create a new one for ngp test mode rendering in the examples.

liruilong940607 commented 1 year ago

@Linyou I also did some cleanups for mark_invisible_cells() and changed the API a tiny bit (the K). Now I'm happy to merge it if you think it's all good.

liruilong940607 commented 1 year ago

Comments addressed. Ready to Go? @Linyou

Linyou commented 1 year ago

@liruilong940607 Yeah! All good!

liruilong940607 commented 1 year ago

Thanks for the patience!! Shipped!