brentyi / tilted

Canonical Factors for Hybrid Neural Fields @ ICCV 2023

Why jax? #1

Closed hiyyg closed 1 year ago

hiyyg commented 1 year ago

Thanks for the great work! I want to ask a question unrelated to the paper:

why did you choose jax? Is it faster and easier to use than pytorch?

brentyi commented 1 year ago

The answer here is not super exciting; I just enjoy working with JAX much more than I enjoy working with PyTorch. Most of my other projects are in PyTorch and I personally find JAX dramatically simpler to use, although this is a contentious topic. 🙂

As far as speed goes, my experience has been very positive: being forced to JIT everything generally results in faster code with good CPU/GPU parallelism, and things like the multi-GPU parallelism in the visualization code are easier to implement.
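As a small illustration of the workflow (a generic sketch, not code from this repository): in JAX you wrap a function with `jax.jit`, it gets traced and compiled to XLA on the first call, and subsequent calls with the same shapes reuse the compiled kernel.

```python
import jax
import jax.numpy as jnp


@jax.jit  # traced + compiled to XLA on first call; later calls reuse the kernel
def mse(pred, target):
    return jnp.mean((pred - target) ** 2)


x = jnp.arange(4.0)       # [0., 1., 2., 3.]
y = jnp.ones(4)           # [1., 1., 1., 1.]
print(mse(x, y))          # squared diffs [1., 0., 1., 4.] -> mean 1.5
```

Because the whole function is compiled as one unit, XLA can fuse the subtraction, square, and mean into fewer kernels than eager op-by-op execution would launch.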

That said, for NeRF stuff specifically I think the PyTorch ecosystem has large advantages. A lot of CUDA tools like tiny-cuda-nn and nerfacc are more readily available. Some NeRF codebases use boolean masking, which relies of dynamic shapes that are possible in PyTorch but not in JAX, and I do sometimes run into obscure JIT-related issues (example: https://github.com/google/jax/discussions/10332).