getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Modern Gradient Functionality? #319

Closed: adam-hartshorne closed this issue 5 months ago

adam-hartshorne commented 11 months ago

KeOps currently uses torch.autograd.Function. Are there any plans to move to the more modern torch.func interface, which would allow using transforms such as jvp, vjp, jacrev, etc.?
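For reference, here is a minimal sketch of the torch.func transforms mentioned above, applied to a plain PyTorch function rather than to KeOps code. It assumes PyTorch 2.0 or later, and the toy function f is purely illustrative:

```python
import torch
from torch.func import grad, jacrev, jvp, vjp


def f(x):
    # Toy scalar-valued function: sum_k x_k * sin(x_k)
    return (x * torch.sin(x)).sum()


x = torch.tensor([0.0, 1.0, 2.0])
v = torch.ones(3)

# Forward mode: returns f(x) and the directional derivative J(x) @ v in one pass
value, tangent = jvp(f, (x,), (v,))

# Reverse mode: vjp returns f(x) and a pullback that maps output cotangents
# to input cotangents
value_r, pullback = vjp(f, x)
(cotangent,) = pullback(torch.tensor(1.0))

g = grad(f)(x)    # gradient of the scalar output
J = jacrev(f)(x)  # Jacobian via reverse mode; shape (3,) since f is scalar-valued
```

Forward mode (jvp) and reverse mode (vjp, grad, jacrev) compute the same derivative here, x * cos(x) + sin(x), through different mechanisms, which is why a custom operation has to support each mode separately.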

joanglaunes commented 9 months ago

Hello @adam-hartshorne, sorry for the very late answer. Following a previous issue, I have implemented support for torch.func tools in a branch called vmap. On this branch, the vmap, grad, vjp and jacrev transforms work, as tested in this script. However, jvp, jacfwd and hessian do not work yet, because they require forward-mode autodiff. I have also implemented forward-mode autodiff support in another branch, so hopefully we can soon merge everything into the main branch and make all torch.func tools available. Let me know if this fully answers your question.

Note that the new code still uses the torch.autograd.Function class, but with additional setup_context and vmap methods.
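The general PyTorch pattern behind "additional setup_context and vmap methods" can be sketched as follows. This is not the KeOps code: the hypothetical Exp class is a toy elementwise exp standing in for a KeOps reduction, assuming PyTorch 2.0 or later. Its forward takes no ctx, tensors for backward are saved in setup_context, and a vmap staticmethod tells torch.func how to handle batched inputs. With no jvp method, only the reverse-mode transforms and vmap are covered:

```python
import torch
from torch.func import grad, jacrev, vjp, vmap


class Exp(torch.autograd.Function):
    """Hypothetical toy elementwise exp written in the torch.func-compatible style."""

    @staticmethod
    def forward(x):
        # No ctx argument: this is what makes the Function usable with torch.func.
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # d/dx exp(x) = exp(x), so saving the output is enough for backward.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result

    @staticmethod
    def vmap(info, in_dims, x):
        # Elementwise op: apply it to the batched tensor directly and report
        # that the output's batch dimension sits where the input's did.
        (x_bdim,) = in_dims
        return Exp.apply(x), x_bdim


def exp_fn(x):
    return Exp.apply(x)


def f(x):
    return exp_fn(x).sum()


x = torch.tensor([0.0, 0.5, 1.0])
g = grad(f)(x)                     # exp(x)
_, pullback = vjp(exp_fn, x)
(ct,) = pullback(torch.ones(3))    # exp(x)
J = jacrev(exp_fn)(x)              # diag(exp(x))

xs = torch.linspace(-1.0, 1.0, 12).reshape(4, 3)
ys = vmap(exp_fn)(xs)              # dispatches to Exp.vmap
per_row_grads = vmap(grad(f))(xs)  # one gradient per row, equal to exp(xs)
```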

joanglaunes commented 5 months ago

Forward-mode autodiff and all the other features discussed in this issue are now supported in pykeops v2.2, which we released today, so I'm closing this issue.
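In plain PyTorch, forward-mode support in a custom autograd.Function comes from a jvp staticmethod, with the tensors it needs saved via ctx.save_for_forward in setup_context. A minimal sketch, assuming PyTorch 2.0 or later, with a hypothetical toy class ExpFwd (this is not the actual pykeops v2.2 implementation):

```python
import torch
from torch.func import hessian, jacfwd, jvp


class ExpFwd(torch.autograd.Function):
    """Hypothetical toy exp supporting both reverse and forward mode."""

    @staticmethod
    def forward(x):
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)  # read by backward (reverse mode)
        ctx.save_for_forward(output)   # read by jvp (forward mode)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward mode: the tangent of exp(x) along dx is exp(x) * dx.
        (result,) = ctx.saved_tensors
        return x_tangent * result

    @staticmethod
    def vmap(info, in_dims, x):
        (x_bdim,) = in_dims
        return ExpFwd.apply(x), x_bdim


def exp_fn(x):
    return ExpFwd.apply(x)


def f(x):
    return exp_fn(x).sum()


x = torch.tensor([0.0, 0.5, 1.0])
v = torch.tensor([1.0, 2.0, 3.0])
out, tangent = jvp(exp_fn, (x,), (v,))  # tangent = exp(x) * v
J = jacfwd(exp_fn)(x)                   # diag(exp(x))
H = hessian(f)(x)                       # diag(exp(x)), forward-over-reverse
```

torch.func.jacfwd is built as vmap over jvp, and hessian as jacfwd over jacrev, so the vmap and backward methods are still needed alongside jvp for these transforms to go through.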