[Closed] Fangkang515 closed this issue 2 years ago.
The JAX code (in its current state) is not optimized for speed; it could likely be sped up substantially with further engineering (in JAX), and indeed this is one of our TODOs going forward. While developing Plenoxels, we often used this higher-level version to try out ideas, and then optimized the final version in CUDA for speed. Certainly the speed of the CUDA version is in part due to hand-engineering the implementation (and even that version could probably be accelerated further), but this hand-engineering speedup was only really possible because of the simplicity of the method. Note also that most automatic differentiation libraries are already optimized for neural networks, so while you could probably get some benefit from re-implementing something like the original (TensorFlow) NeRF in CUDA, I'm not sure how large the speedup would be.
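To illustrate the point, here is a minimal sketch (not the repository's actual code; `trilinear_lookup`, the grid shape, and the batch size are assumptions for illustration) of the kind of high-level voxel-grid interpolation a JAX version expresses. Each step materializes intermediate arrays, whereas a hand-written CUDA kernel can fuse the whole lookup into one pass over memory, which is one reason the gap can be large even though JAX itself also runs on the GPU.

```python
import jax
import jax.numpy as jnp

@jax.jit
def trilinear_lookup(grid, pts):
    """grid: (X, Y, Z, C) voxel features; pts: (N, 3) continuous coordinates
    in voxel units. Returns (N, C) trilinearly interpolated features."""
    lo = jnp.floor(pts).astype(jnp.int32)      # lower corner indices
    frac = pts - lo                            # fractional offsets in [0, 1)
    res = jnp.array(grid.shape[:3])
    lo = jnp.clip(lo, 0, res - 2)
    hi = lo + 1

    def gather(ix, iy, iz):
        # Advanced indexing gathers one corner per sample point: (N, C)
        return grid[ix, iy, iz]

    fx, fy, fz = frac[:, 0:1], frac[:, 1:2], frac[:, 2:3]
    c000 = gather(lo[:, 0], lo[:, 1], lo[:, 2])
    c100 = gather(hi[:, 0], lo[:, 1], lo[:, 2])
    c010 = gather(lo[:, 0], hi[:, 1], lo[:, 2])
    c110 = gather(hi[:, 0], hi[:, 1], lo[:, 2])
    c001 = gather(lo[:, 0], lo[:, 1], hi[:, 2])
    c101 = gather(hi[:, 0], lo[:, 1], hi[:, 2])
    c011 = gather(lo[:, 0], hi[:, 1], hi[:, 2])
    c111 = gather(hi[:, 0], hi[:, 1], hi[:, 2])

    # Interpolate along x, then y, then z; every line below is a separate
    # GPU array op unless XLA manages to fuse it.
    c00 = c000 * (1 - fx) + c100 * fx
    c10 = c010 * (1 - fx) + c110 * fx
    c01 = c001 * (1 - fx) + c101 * fx
    c11 = c011 * (1 - fx) + c111 * fx
    c0 = c00 * (1 - fy) + c10 * fy
    c1 = c01 * (1 - fy) + c11 * fy
    return c0 * (1 - fz) + c1 * fz

# Example usage with an assumed 128^3 grid of 4-channel features.
grid = jnp.zeros((128, 128, 128, 4))
pts = jax.random.uniform(jax.random.PRNGKey(0), (8192, 3)) * 126.0
feats = trilinear_lookup(grid, pts)  # (8192, 4)
```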
Thanks for your amazing work.
“but is much slower (roughly 1 hour per epoch) than the CUDA implementation https://github.com/sxyu/svox2 (roughly 1 minute per epoch)”
As mentioned above, the JAX implementation is much slower than the CUDA implementation. As far as I know, JAX can also be accelerated on the GPU (CUDA). Why is the speed gap so large? And is the fast training of Plenoxels due to the CUDA implementation rather than to not using a neural network?