NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

nvdiffrast + Jax? #38

Closed by Daiver 3 years ago

Daiver commented 3 years ago

Hi! I have a theoretical question. Nvdiffrast currently supports PyTorch and TensorFlow, and as far as I can see, a lot of the code is reused for both versions. How much effort would be needed to add support for Jax? I understand that Jax support is outside the nvdiffrast dev team's plans, I'm just wondering. Please let me know if GitHub issues are not the appropriate place for such questions.

s-laine commented 3 years ago

All of the CUDA code that does the heavy lifting, and the C++ used for the OpenGL parts, is indeed shared between the frameworks. I don't know anything about Jax, but if it operates on tensors stored as dense arrays in GPU memory like the other frameworks do, then supporting it would most likely involve writing interface layers (in C++ and Python) similar to the ones we now have separately for TensorFlow and PyTorch. These layers are a decent amount of code, but writing them is fairly straightforward, at least in the two frameworks we support now.

But it is true that supporting Jax or other new frameworks is not on our roadmap right now. It appears that PyTorch is overwhelmingly the more popular framework, and TensorFlow support has gotten somewhat patchy already; some new features have been implemented only on the PyTorch side, and nobody has complained. The next change related to framework support, whenever that may be, will more likely be dropping TensorFlow support altogether rather than adding new frameworks.
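
For readers wondering what the Python half of such an interface layer could look like, here is a minimal sketch in JAX. It is not nvdiffrast code and not an existing API: `make_interpolate`, `forward_kernel`, and `backward_kernel` are hypothetical names, and the two "kernels" are written in plain `jax.numpy` as stand-ins for calls into the shared CUDA code. The op itself (barycentric interpolation of per-vertex attributes) is a simplified assumption chosen for illustration. The only point is the pattern: a forward function, a hand-written backward function, and `jax.custom_vjp` to tell the framework to use them together.

```python
import jax
import jax.numpy as jnp


def make_interpolate(vertex_idx):
    """Build a differentiable op for a fixed [P, 3] integer array of vertex indices.

    Hypothetical helper: the indices are closed over so that the op itself
    only takes floating-point inputs.
    """

    def forward_kernel(attr, bary):
        # Stand-in for a shared forward kernel (assumption, plain JAX here).
        # attr: [V, C] vertex attributes, bary: [P, 3] weights -> [P, C]
        return jnp.einsum("pk,pkc->pc", bary, attr[vertex_idx])

    def backward_kernel(attr, bary, grad_out):
        # Stand-in for a shared backward kernel (assumption, plain JAX here).
        corner_attr = attr[vertex_idx]  # [P, 3, C]
        grad_bary = jnp.einsum("pc,pkc->pk", grad_out, corner_attr)
        contrib = bary[:, :, None] * grad_out[:, None, :]  # [P, 3, C]
        # Scatter-add so vertices shared by several pixels accumulate gradients.
        grad_attr = jnp.zeros_like(attr).at[vertex_idx].add(contrib)
        return grad_attr, grad_bary

    @jax.custom_vjp
    def interpolate(attr, bary):
        return forward_kernel(attr, bary)

    def interpolate_fwd(attr, bary):
        # Return the output plus the residuals needed by the backward pass.
        return forward_kernel(attr, bary), (attr, bary)

    def interpolate_bwd(residuals, grad_out):
        attr, bary = residuals
        # One gradient per input of interpolate, in the same order.
        return backward_kernel(attr, bary, grad_out)

    interpolate.defvjp(interpolate_fwd, interpolate_bwd)
    return interpolate


# Tiny usage example with made-up data: 4 vertices, 2 channels, 2 pixels.
vertex_idx = jnp.array([[0, 1, 2], [1, 2, 3]])
attr = jnp.arange(8.0).reshape(4, 2)
bary = jnp.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]])

interpolate = make_interpolate(vertex_idx)
loss = lambda a, b: jnp.sum(interpolate(a, b) ** 2)
grad_attr, grad_bary = jax.grad(loss, argnums=(0, 1))(attr, bary)
```

In a real port, the bodies of the two stand-in kernels would presumably be replaced by bindings to the existing CUDA code, which is where most of the C++ glue work mentioned above would go.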

Daiver commented 3 years ago

Thank you for such a detailed explanation!