blurgyy / jaxngp

JAX implementation of instant-ngp (NeRF part)
Apache License 2.0

where is shjax ? #6

Open Chutlhu opened 8 months ago

Chutlhu commented 8 months ago

Dear author,

thank you very much for this repository. I am interested in the spherical harmonics encoding. Could you provide some more information about the shjax library? I cannot find it online.

thank you very much

blurgyy commented 8 months ago

Hi @Chutlhu,

Thank you for your interest.

I implemented shjax as a custom CUDA extension in deps/spherical-harmonics-encoding-jax/, and it is integrated into Python via

https://github.com/blurgyy/jaxngp/blob/d63c2c9f30b5e77c0ea212ceab94d62529f7a887/models/encoders.py#L351-L357

You can inspect its source there.

However, I ended up using a pure-JAX implementation of the spherical harmonics encoding instead, because in my benchmarks the JAX implementation is consistently faster than the custom CUDA implementation. I think this is because JAX code can be more easily optimized by the compiler through operations like kernel fusion. The JAX implementation used throughout the project can be found at https://github.com/blurgyy/jaxngp/blob/d63c2c9f30b5e77c0ea212ceab94d62529f7a887/models/encoders.py#L361
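For readers curious what such a pure-JAX spherical harmonics encoding looks like, here is a minimal hypothetical sketch (not the repo's actual code): a degree-2 real SH basis evaluated on unit direction vectors, as commonly used for NeRF view-direction encoding. The coefficients follow one common real-SH sign convention; the exact ordering and signs in jaxngp or instant-ngp may differ.

```python
# Sketch of a real spherical harmonics encoding (degrees 0-2, 9 dims)
# written in pure JAX; XLA can fuse all of these elementwise ops into
# a single kernel, which is the optimization mentioned above.
import jax
import jax.numpy as jnp

@jax.jit
def sh_encode_deg2(d: jnp.ndarray) -> jnp.ndarray:
    """Encode unit direction(s) d (..., 3) into 9 SH basis values."""
    x, y, z = d[..., 0], d[..., 1], d[..., 2]
    return jnp.stack([
        0.28209479177387814 * jnp.ones_like(x),     # l=0
        -0.48860251190291987 * y,                   # l=1, m=-1
        0.48860251190291987 * z,                    # l=1, m=0
        -0.48860251190291987 * x,                   # l=1, m=1
        1.0925484305920792 * x * y,                 # l=2, m=-2
        -1.0925484305920792 * y * z,                # l=2, m=-1
        0.31539156525252005 * (3.0 * z * z - 1.0),  # l=2, m=0
        -1.0925484305920792 * x * z,                # l=2, m=1
        0.5462742152960396 * (x * x - y * y),       # l=2, m=2
    ], axis=-1)

dirs = jnp.array([[0.0, 0.0, 1.0]])  # a batch of one unit direction
print(sh_encode_deg2(dirs).shape)    # (1, 9)
```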

The benchmark I used to compare the JAX and CUDA implementations of the spherical harmonics encoding is at https://github.com/blurgyy/jaxngp/blob/d63c2c9f30b5e77c0ea212ceab94d62529f7a887/models/encoders.py#L467
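As a side note for anyone reproducing such a comparison, here is a hypothetical sketch of the usual pattern for timing JAX code fairly (the repo's actual benchmark lives at the link above): trigger JIT compilation before the timed loop, and call `block_until_ready()` so JAX's asynchronous dispatch does not make the measurement meaningless.

```python
# Generic JAX micro-benchmark pattern: warm up the JIT cache first,
# then time repeated calls and block on the final result.
import time
import jax
import jax.numpy as jnp

def bench(fn, *args, iters=100):
    fn(*args).block_until_ready()   # warm-up: compile outside the timed loop
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()         # wait for all dispatched work to finish
    return (time.perf_counter() - t0) / iters

f = jax.jit(lambda v: jnp.sin(v) * jnp.cos(v))  # stand-in for an encoder
x = jnp.ones((1_000_000,))
print(f"{bench(f, x) * 1e3:.3f} ms/iter")
```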

Cheers!

Chutlhu commented 8 months ago

Dear @blurgyy, thank you very much! I found it and played with it a little. With this positional embedding, the model overfits the training data very well (compared to standard Random Fourier Features), but it seems to lose the native interpolation property. Do you know anything about this? Do you have any references on this problem?