SNMS95 opened 1 year ago
@mblondel Is this going to be implemented?
I need this because I have a black-box solver from PETSc that I would like to use to solve the linear system.
Do you know if I can use `jax.lax.custom_linear_solve` instead?
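A minimal sketch of plugging a black-box solver into `jax.lax.custom_linear_solve`, assuming the external solver can be called from Python on NumPy arrays (for PETSc that would typically mean something like petsc4py; here `np.linalg.solve` stands in for it). `external_solve`, `blackbox_solve` and `linear_solve` are hypothetical names. `jax.pure_callback` hands the arrays to host code, while `custom_linear_solve` supplies the derivative rule, so JAX never differentiates through the callback; `transpose_solve` is what reverse mode uses.

```python
import jax
import jax.numpy as jnp
import numpy as np


def external_solve(A, b):
    # Hypothetical stand-in for a black-box solver (e.g. PETSc via petsc4py).
    # It receives host arrays and returns a NumPy array of b's dtype.
    A, b = np.asarray(A), np.asarray(b)
    return np.linalg.solve(A, b).astype(b.dtype)


def blackbox_solve(A, b):
    # Call the host solver from traced JAX code; JAX treats it as opaque.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(external_solve, out_spec, A, b)


def linear_solve(A, b):
    # custom_linear_solve differentiates implicitly through A x = b,
    # so the black-box solver itself is never differentiated.
    matvec = lambda x: A @ x
    solve = lambda _matvec, rhs: blackbox_solve(A, rhs)
    transpose_solve = lambda _vecmat, rhs: blackbox_solve(A.T, rhs)
    return jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve)


A = jnp.array([[4.0, 1.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = linear_solve(A, b)
loss = lambda A, b: jnp.sum(linear_solve(A, b) ** 2)
grad_A, grad_b = jax.grad(loss, argnums=(0, 1))(A, b)
```

Only `solve` and `transpose_solve` touch the external code; `matvec` has to remain a JAX function, because the derivative with respect to `A` is taken through it.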
Is this planned? It would make it easy to nest optimizers.
I was trying to use the `custom_root` decorator to differentiate through a solver. Taking gradients works well. However, if I try to use `jax.hessian`, I get the error "cannot use forward-mode autodiff with a custom_vjp function". According to the JAX documentation, a function with a custom derivative rule supports both modes of differentiation only if the rule is defined with `custom_jvp` rather than `custom_vjp`.

I saw that internally, `custom_root` implements only a `custom_vjp` rule. Is there any way to choose a `custom_jvp` rule instead? A minimal example is as follows:
with the error
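A minimal sketch of the `custom_jvp` route for a root-finding solver, assuming the optimality condition F(x, θ) = 0 can be written as a JAX function. The names `optimality_fun`, `_newton_solve` and `root_solve` and the toy condition x + tanh(x) - θ = 0 are hypothetical, and this is not how `custom_root` is implemented. The JVP rule applies the implicit function theorem, ∂ₓF · ẋ = -∂_θF · θ̇. Because that tangent is linear in θ̇, JAX can transpose it for reverse mode, and because the rule calls `root_solve` again, the rule can itself be differentiated, which is what `jax.hessian` needs.

```python
import jax
import jax.numpy as jnp


def optimality_fun(x, theta):
    # Hypothetical optimality condition F(x, theta) = 0, applied elementwise.
    return x + jnp.tanh(x) - theta


def _newton_solve(theta):
    # Stand-in for an opaque inner solver: 50 Newton steps from x = 0.
    # Its internals are never differentiated (root_solve wraps it).
    def step(_, x):
        fx = optimality_fun(x, theta)
        dfx = 2.0 - jnp.tanh(x) ** 2  # dF/dx = 1 + sech^2(x)
        return x - fx / dfx
    return jax.lax.fori_loop(0, 50, step, jnp.zeros_like(theta))


@jax.custom_jvp
def root_solve(theta):
    return _newton_solve(theta)


@root_solve.defjvp
def root_solve_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    # Calling root_solve (not _newton_solve) keeps the rule differentiable.
    x = root_solve(theta)
    # Implicit function theorem: dF/dx @ x_dot = -(dF/dtheta @ theta_dot)
    _, f_theta_dot = jax.jvp(lambda t: optimality_fun(x, t), (theta,), (theta_dot,))
    jac_x = jax.jacfwd(optimality_fun, argnums=0)(x, theta)
    x_dot = jnp.linalg.solve(jac_x, -f_theta_dot)
    return x, x_dot


def loss(theta):
    return jnp.sum(root_solve(theta) ** 2)


theta = jnp.array([0.3, -1.2, 2.0])
grad_theta = jax.grad(loss)(theta)     # reverse mode
hess_theta = jax.hessian(loss)(theta)  # forward-over-reverse
```

Building the dense Jacobian with `jax.jacfwd` and calling `jnp.linalg.solve` is only reasonable for small `x`; for larger problems a matrix-free iterative solver (for example `jax.scipy.sparse.linalg.gmres` on JVP-based matvecs) could take its place.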