sarafridov / plenoxels

JAX implementation of Plenoxels

About the speed for jax plenoxels #5

Closed Fangkang515 closed 2 years ago

Fangkang515 commented 2 years ago

Thanks for your astonishing work.

“but is much slower (roughly 1 hour per epoch) than the CUDA implementation https://github.com/sxyu/svox2 (roughly 1 minute per epoch)”

As mentioned above, the JAX version is much slower than the CUDA implementation. As far as I know, JAX can also be accelerated on GPU (CUDA). Why is the speed gap so large? And is the fast training of Plenoxels due to the hand-written CUDA implementation rather than to not using a neural network?

sarafridov commented 2 years ago

The JAX code (in its current state) is not optimized for speed; it could likely be sped up substantially by further engineering (in JAX), and indeed this is one of our TODOs going forward. While developing Plenoxels, we often used this higher-level version to try out ideas, and then optimized the final version in CUDA for speed. Certainly the speed of the CUDA version is in part due to hand-engineering the implementation (and even that version could probably be accelerated further), but this hand-engineering speedup was only really possible because of the simplicity of the method. Note also that most automatic differentiation libraries are already optimized for neural networks, so while you could probably get some benefit from re-implementing something like original (TensorFlow) NeRF in CUDA, I'm not sure how much the speedup would be.
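To make the gap concrete, here is a minimal sketch (not the actual Plenoxels code; the grid shapes and function names are illustrative) of the kind of voxel-grid lookup the JAX version performs: trilinearly interpolating feature vectors from a dense grid at continuous sample points. Under `jax.jit`, XLA compiles this into a sequence of gather and arithmetic kernels, but a hand-written CUDA kernel can fuse the entire ray march, interpolation, and compositing into a single launch, which is where much of the remaining speed difference comes from.

```python
import jax
import jax.numpy as jnp


def trilerp(grid, pts):
    """Trilinearly interpolate a dense feature grid.

    grid: (X, Y, Z, C) array of per-voxel features.
    pts:  (N, 3) continuous coordinates in voxel units.
    Returns (N, C) interpolated features.
    """
    lo = jnp.floor(pts).astype(jnp.int32)              # (N, 3) lower corner index
    frac = pts - lo                                    # (N, 3) offset in [0, 1)
    # Clamp so the +1 corner stays inside the grid.
    lo = jnp.clip(lo, 0, jnp.array(grid.shape[:3]) - 2)

    out = 0.0
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                # Gather the feature vector at this corner: (N, C).
                corner = grid[lo[:, 0] + dx, lo[:, 1] + dy, lo[:, 2] + dz]
                # Standard trilinear weight for this corner.
                w = (jnp.where(dx, frac[:, 0], 1 - frac[:, 0])
                     * jnp.where(dy, frac[:, 1], 1 - frac[:, 1])
                     * jnp.where(dz, frac[:, 2], 1 - frac[:, 2]))
                out = out + w[:, None] * corner
    return out


# jit-compiling fuses the eight gathers and weight products via XLA,
# but still issues them as generic ops rather than one custom kernel.
trilerp_fast = jax.jit(trilerp)

grid = jnp.ones((8, 8, 8, 4))
pts = jnp.array([[1.5, 2.25, 3.0], [0.0, 0.0, 0.0]])
feats = trilerp_fast(grid, pts)
print(feats.shape)  # (2, 4)
```

Because the weights at the eight corners sum to one, interpolating a grid of ones returns ones, which is a quick sanity check that the weighting is correct.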