lab-cosmo / sphericart

Multi-language library for the calculation of spherical harmonics in Cartesian coordinates
https://sphericart.readthedocs.io/en/latest/
MIT License
73 stars, 13 forks

Remove weird `sphericart.torch` mechanics for double backward #139

Open frostedoyster opened 3 months ago

frostedoyster commented 3 months ago

In torch, there is no way to know whether the user will request a second derivative (unlike the first derivative, where `requires_grad` can be checked). As a result, the current API requires the user to specify at class initialization whether second derivatives will be used. This is pretty useless for two reasons:

The only way I see to make this feature usable is to compute the second derivatives on the fly, only when they are actually requested. This recomputes the values and first derivatives of the spherical harmonics, but the current approach, which avoids the recomputation, is unsustainable in practice. We should also find a way to mark the second-derivative function as non-differentiable, to avoid, once again, silent failures if people try to differentiate three or more times. We need something similar to `@once_differentiable` (https://discuss.pytorch.org/t/what-does-the-function-wrapper-once-differentiable-do/31513), but for the C++ torch API.
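A minimal Python sketch of the on-the-fly approach, assuming a toy elementwise function `sin(x)` in place of the spherical harmonics (the classes `ToySH` and `ToySHBackward` are hypothetical, not part of sphericart). The first backward delegates to a second `torch.autograd.Function`, so second derivatives are only computed, from recomputed quantities, if someone differentiates through it. Its own backward uses PyTorch's Python-side `once_differentiable` decorator; the C++ counterpart is what this issue is still looking for.

```python
import torch
from torch.autograd.function import once_differentiable


class ToySH(torch.autograd.Function):
    """Hypothetical toy op: y = sin(x) stands in for the spherical harmonics."""

    @staticmethod
    def forward(ctx, x):
        # Only the input is saved; derivatives are recomputed later if needed
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Delegating to another Function means a double backward, if any,
        # is handled lazily there instead of being decided at construction time
        return ToySHBackward.apply(x, grad_out)


class ToySHBackward(torch.autograd.Function):
    """Hypothetical first-derivative op; its backward holds the second derivatives."""

    @staticmethod
    def forward(ctx, x, grad_out):
        ctx.save_for_backward(x, grad_out)
        # First derivative of sin(x), recomputed from x
        return grad_out * torch.cos(x)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_grad):
        x, grad_out = ctx.saved_tensors
        # Second derivatives, computed only when a double backward is requested.
        # once_differentiable runs this under no_grad, so no graph is recorded
        # for a third differentiation.
        grad_x = grad_grad * grad_out * (-torch.sin(x))
        grad_grad_out = grad_grad * torch.cos(x)
        return grad_x, grad_grad_out
```

With this layout nothing has to be declared up front: a plain `.backward()` never touches the second-derivative code, and a `create_graph=True` backward pays for the recomputation only when it is used.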

ceriottm commented 3 months ago

Maybe use a separate class for second derivatives?

frostedoyster commented 1 month ago

This is now partially addressed with warnings. Once we can reasonably require torch 2.4 as a minimum version, we will be able to fix it once and for all with the custom operator API described in https://pytorch.org/tutorials/advanced/cpp_custom_ops.html

Luthaf commented 1 month ago

I'm not sure I see how the link you shared would fix the issue. Is it by allowing the backward pass to be defined on the Python side?