Closed: adam-hartshorne closed this issue 5 months ago.
Hello @adam-hartshorne,
Sorry for the very late answer. Following a previous issue, I have implemented support for torch.func tools in a branch called `vmap`. On this branch, the `vmap`, `grad`, `vjp` and `jacrev` functions work, as tested in this script.
However, `jvp`, `jacfwd` and `hessian` do not work yet, because they require forward-mode autodiff. I have also implemented support for forward-mode autodiff in another branch, so hopefully we can soon merge everything into the main branch, and all torch.func tools will then be available.
Let me know if this fully answers your question. Note that the new code still uses the `torch.autograd.Function` class, but with additional `setup_context` and `vmap` methods.
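As a rough illustration of that pattern (a toy elementwise square, not the actual KeOps implementation; the class name `Square` is hypothetical and PyTorch 2.0 or later is assumed): `forward` no longer receives `ctx`, saving tensors for the backward pass moves into `setup_context`, and a `vmap` staticmethod supplies the batching rule. With these in place, `vmap`, `grad`, `vjp` and `jacrev` from `torch.func` can be applied to the function.

```python
import torch
from torch.func import vmap, grad, vjp, jacrev


class Square(torch.autograd.Function):
    """Hypothetical toy op: elementwise x**2 in the torch.func-compatible style."""

    @staticmethod
    def forward(x):
        # No ctx argument here; state is saved in setup_context instead.
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        # d/dx x**2 = 2x, multiplied by the incoming gradient.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

    @staticmethod
    def vmap(info, in_dims, x):
        # Batching rule: x carries its batch dimension at in_dims[0]. The op is
        # elementwise, so apply it directly and keep the batch dim where it was.
        (x_bdim,) = in_dims
        return Square.apply(x), x_bdim


x = torch.tensor([1.0, 2.0, 3.0])

# vmap: map Square over the leading dimension of a batch.
batched = vmap(Square.apply)(torch.arange(6.0).reshape(2, 3))

# grad: gradient of the scalar sum(x**2), i.e. 2x.
g = grad(lambda v: Square.apply(v).sum())(x)

# vjp: forward output plus a function computing vector-Jacobian products.
out, vjp_fn = vjp(Square.apply, x)
(v,) = vjp_fn(torch.ones_like(x))

# jacrev: full Jacobian via reverse mode, diag(2x) for an elementwise square.
J = jacrev(Square.apply)(x)
```

Writing the `vmap` rule by hand gives full control over where the batch dimension goes; for functions built only from PyTorch operations, setting `generate_vmap_rule = True` on the class is a simpler alternative.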
Support for forward autodiff and all features discussed in this issue are now implemented in pykeops v2.2 that we released today, so I'm closing this issue.
KeOps currently uses torch.autograd.Function. Are there any plans to move to the more modern torch.func interface, which would then allow for jvp, vjp, jacrev, etc.?