nerfstudio-project / nerfstudio

A collaboration friendly studio for NeRFs
https://docs.nerf.studio
Apache License 2.0

PyTorch HashGrid implementation #1200

Open stefan-baumann opened 1 year ago

stefan-baumann commented 1 year ago

Hi, is there a reason why you tend not to use your own implementation of the hash grid encoding? Everywhere I look in the codebase, you seem to use the tcnn implementation directly instead of your wrapper. Is your own implementation not considered finished and ready to use yet? Also, have you done any tests regarding the consistency of the results between the tcnn and custom PyTorch backends? Taking a quick look at the code, I couldn't find the offset used by the tcnn implementation in your version: https://github.com/NVlabs/tiny-cuda-nn/blob/ae6361103d9b18ddfee74668f5ba3d80410df2ac/include/tiny-cuda-nn/common_device.h#L428
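
For concreteness, here is a rough sketch of where that offset would enter a PyTorch port, assuming the offset in question is the half-cell shift tcnn applies when scaling the input coordinates before taking the grid corners (the function and variable names below are made up for illustration, not the actual nerfstudio code):

```python
import torch

def scaled_corners(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical sketch: tcnn scales the input and adds 0.5 before computing the
    # floor/ceil lattice corners. A port that omits the "+ 0.5" samples the grid at
    # slightly shifted lattice positions, so its outputs won't match tcnn exactly.
    scaled = x * scale + 0.5  # the offset in question (assumption)
    return torch.floor(scaled).long(), torch.ceil(scaled).long()
```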

tancik commented 1 year ago

tcnn is used everywhere since it is much faster. At some point I think we should dig into the PyTorch implementation to 1. confirm that it is accurate, and 2. see if we can get it running faster. You are correct that the offset is currently missing.
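
A rough sketch of how such an accuracy check could look, assuming tinycudann is installed (the HashEncoding constructor arguments and config values below are illustrative and may differ by version; a real comparison would also need both backends to share the same hash-table parameters and level layout rather than independent random initializations):

```python
import torch
import tinycudann as tcnn
from nerfstudio.field_components.encodings import HashEncoding

encoding_config = {
    "otype": "HashGrid",
    "n_levels": 16,
    "n_features_per_level": 2,
    "log2_hashmap_size": 19,
    "base_resolution": 16,
    "per_level_scale": 1.447,
}

# tcnn backend and the pure-PyTorch encoding, instantiated side by side.
tcnn_enc = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_config)
torch_enc = HashEncoding(num_levels=16, features_per_level=2,
                         log2_hashmap_size=19, min_res=16, max_res=1024).to("cuda")

x = torch.rand(1024, 3, device="cuda")
out_tcnn = tcnn_enc(x)    # (1024, 32)
out_torch = torch_enc(x)  # (1024, 32)

# With independently initialized tables only shapes/statistics are comparable;
# matching values would require copying the hash table between backends.
print(out_tcnn.shape, out_torch.shape)
```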

stefan-baumann commented 1 year ago

Thanks for the quick response! Regarding the question about tcnn, I was primarily interested in why you seem to always use the tcnn version directly instead of your own class, which wraps tcnn by default. As the wrapper is quite minimal, I wouldn't expect the single additional function call to cause any noticeable slowdown. If you used the wrapper instead, the automatic fallback to the custom PyTorch implementation would kick in when tcnn is not available, which I assume was the whole point of implementing this class in the first place?
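
Something like the following is the fallback pattern I mean, under the assumption that the wrapper simply dispatches to tcnn when it imports and to the PyTorch path otherwise (class and module names here are illustrative, not the actual nerfstudio code):

```python
import torch

try:
    import tinycudann as tcnn
    TCNN_EXISTS = True
except ImportError:
    TCNN_EXISTS = False

class HashGridWrapper(torch.nn.Module):
    """Illustrative wrapper: prefer the tcnn encoding, fall back to pure PyTorch."""

    def __init__(self, encoding_config: dict):
        super().__init__()
        if TCNN_EXISTS:
            self.encoding = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_config)
        else:
            # Hypothetical pure-PyTorch implementation used when tcnn is unavailable.
            self.encoding = PurePyTorchHashEncoding(encoding_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoding(x)
```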

tancik commented 1 year ago

Yes, that is indeed the purpose of the class. The issue is that when using an MLP + hash grid, it is fastest to use tcnn.NetworkWithInputEncoding rather than tcnn.Encoding followed by tcnn.Network. Therefore our hash encoding + tcnn.Network wouldn't be as fast. The solution is to create a NetworkWithInputEncoding that operates similarly to how we created our PyTorch hash encoding. I just haven't had a chance to do that yet.
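
To illustrate the two construction paths being compared, here is a sketch using the tiny-cuda-nn Python bindings (the config values are abbreviated examples, not nerfstudio defaults). The fused module keeps the encoding and MLP inside tcnn rather than round-tripping the encoded features through PyTorch between two modules, which is presumably where the speed difference comes from:

```python
import tinycudann as tcnn

encoding_config = {"otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2,
                   "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": 1.447}
network_config = {"otype": "FullyFusedMLP", "activation": "ReLU",
                  "output_activation": "None", "n_neurons": 64, "n_hidden_layers": 2}

# Fastest path: encoding and MLP fused into a single tcnn module.
fused = tcnn.NetworkWithInputEncoding(
    n_input_dims=3, n_output_dims=16,
    encoding_config=encoding_config, network_config=network_config)

# Slower path: separate modules, with the encoded features passing through PyTorch
# between the encoding and the network. Swapping the first module for a custom
# PyTorch hash encoding keeps this same split, hence the same overhead.
encoding = tcnn.Encoding(n_input_dims=3, encoding_config=encoding_config)
network = tcnn.Network(n_input_dims=encoding.n_output_dims, n_output_dims=16,
                       network_config=network_config)
```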