zou3519 opened this issue 2 years ago
Hey! I'm actually trying to implement this (#334). Is there any way I can get involved or ask for this to be accelerated? From my experience, PyTorch calculates gradients a bit faster than JAX (based on some small experiments I've run), and right now I'm trying to build the kernel off of PyTorch gradients. I've succeeded in doing so, but it doesn't seem very efficient. There are more details in my issue.
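For context, a minimal sketch of the kind of computation I mean — an empirical NTK via `vmap`/`jacrev` from functorch. The model, shapes, and contraction below are placeholders, not my actual setup from #334:

```python
import torch
from functorch import make_functional, vmap, jacrev

# Placeholder model; the real network in #334 is different.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)
fmodel, params = make_functional(model)

def f(params, x):
    # Operates on a single example; vmap supplies the batch dimension.
    return fmodel(params, x.unsqueeze(0)).squeeze(0)

x1 = torch.randn(8, 4)
x2 = torch.randn(5, 4)

# Per-sample Jacobians of the outputs w.r.t. the parameters.
jac1 = vmap(jacrev(f), (None, 0))(params, x1)  # tuple of [8, out, *p.shape]
jac2 = vmap(jacrev(f), (None, 0))(params, x2)  # tuple of [5, out, *p.shape]

# Contract over parameter dimensions: the kernel is [8, 5, out, out].
ntk = sum(
    torch.einsum('Naf,Mbf->NMab', j1.flatten(2), j2.flatten(2))
    for j1, j2 in zip(jac1, jac2)
)
```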
@mohamad-amin let's continue the discussion in the other thread -- we're happy to help debug your use case. I was planning on reading about neural tangents later this week and trying an implementation after that, so maybe that will also help.
I've brought this up in a few issues, both here and on PyTorch, but machine-learning potentials are a big one: to ensure mathematical correctness, we predict energies, but we are interested in both the energies and their derivatives w.r.t. various inputs. For more complicated scenarios this almost always means the sort of advanced gradients (efficient per-sample gradients, Jacobians, Hessians) that functorch is working on.
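To make that concrete, here is a minimal sketch assuming a toy pairwise energy function of atomic positions (a real potential would be a trained model, but the derivative pattern is the same): forces as the gradient of the energy, a Hessian for the same function, and per-configuration forces via `vmap`.

```python
import torch
from functorch import grad, vmap, hessian

def energy(positions):
    # positions: [n_atoms, 3]; a toy inverse-square pairwise "potential".
    diffs = positions.unsqueeze(0) - positions.unsqueeze(1)    # [n, n, 3]
    eye = torch.eye(positions.shape[0])
    r2 = (diffs ** 2).sum(-1) + eye                            # avoid 0 on the diagonal
    inv = 1.0 / r2 - eye                                       # zero out self-interaction
    return 0.5 * inv.sum()                                     # scalar energy

positions = torch.randn(10, 3)

# Forces are the negative gradient of the energy w.r.t. positions.
forces = -grad(energy)(positions)            # [n_atoms, 3]

# Hessian of the same energy function (e.g. for vibrational analysis).
hess = hessian(energy)(positions)            # [n_atoms, 3, n_atoms, 3]

# Per-configuration forces over a batch of structures via vmap.
batch = torch.randn(32, 10, 3)
batch_forces = -vmap(grad(energy))(batch)    # [32, n_atoms, 3]
```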
We want to have some more test cases. Here are some cool things that work with JAX that would be interesting to port over to functorch.
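As a purely illustrative example of the kind of parallel we're after (not tied to any particular item), here is the same vmap-of-grad composition written in both libraries:

```python
# JAX: per-sample gradients as a composition of transforms.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x).sum()

jax_grads = jax.vmap(jax.grad(f))(jnp.ones((8, 3)))   # [8, 3]

# functorch: the same composition with the PyTorch analogues.
import torch
from functorch import vmap, grad

def g(x):
    return torch.sin(x).sum()

pt_grads = vmap(grad(g))(torch.ones(8, 3))            # [8, 3]
```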