francesco-innocenti opened 1 month ago
What you're doing looks reasonable to me. The error you're getting is coming from using either `jax.jvp` or `jax.jacfwd`.

Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff. By default, Diffrax provides custom reverse-mode autodiff for `diffeqsolve`. You might like to try the solution of https://github.com/patrick-kidger/optimistix/pull/51#issuecomment-2105948574, which provides an alternate adjoint method that supports forward mode only instead.
> Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff.

Why is this? This doesn't seem fundamentally impossible.
`jax.custom_jvp` and `jax.custom_vjp` cannot both be applied: one of them has to be on the outside, and that is the one that autodiff sees. The only way around this is to define a custom primitive, but that's not exactly easy! In particular for our use case here, where our custom autodiff is a loop over steps (`eqx.internal.while_loop(..., kind="checkpointed")`).
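A minimal sketch of why only the outermost custom rule is visible (the function `f` and its rules here are a toy stand-in for how `diffeqsolve` exposes custom reverse-mode autodiff): reverse mode is routed through the `custom_vjp` rule and works, while forward mode hits the wrapper first and fails with exactly the `TypeError` from this thread.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for diffeqsolve: a function with a custom reverse-mode
# (custom_vjp) rule. The name `f` and its rules are hypothetical.
@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    return jnp.sin(x), x  # primal output plus residual for the backward pass

def f_bwd(res, g):
    x = res
    return (g * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

# Reverse mode goes through the custom rule:
grad_val = jax.grad(f)(1.0)

# Forward mode sees the custom_vjp wrapper first and fails:
try:
    jax.jacfwd(f)(1.0)
    fwd_error = None
except TypeError as e:
    fwd_error = e  # "can't apply forward-mode autodiff (jvp) to a custom_vjp function"
```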
`jax.hessian` composes `jacfwd(jacrev(function))`. `jacrev(jacrev(function))` or `jacfwd(jacfwd(function))` both work, so you can adapt to the adjoint of your choice.
Hi!

This is a follow-up on #181. The use case is to evaluate derivatives (e.g. the gradient or Hessian) of some loss function $\mathcal{L}$ with respect to some variable $\theta$, at the equilibrium of that loss's gradient with respect to some other variable $y$, i.e. where $\partial \mathcal{L}/\partial y \approx 0$. Mathematically this would be something like

$$\frac{\partial \mathcal{L}(y; \theta)}{\partial \theta}\bigg|_{\frac{\partial \mathcal{L}}{\partial y} \approx 0}$$
In code, building on your snippet from #181: given those definitions, I could just solve for `y` and then take the gradient with respect to `theta`. However, this ignores the dependence between `y` and `theta` that arises in the integration of the gradient system. So ideally I would like to take the gradient of the loss while solving for `y` inside that same loss.
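As a hedged sketch of that "solve inside the loss" pattern (all names here — `inner_loss`, `solve_y`, `outer_loss` — are made up, and plain unrolled gradient descent under `jax.lax.scan` stands in for the actual `diffeqsolve` gradient flow):

```python
import jax
import jax.numpy as jnp

# Hypothetical inner loss L(y; theta) whose gradient equilibrium in y we want.
def inner_loss(y, theta):
    return jnp.sum((y - theta) ** 2) + 0.1 * jnp.sum(y ** 2)

# Stand-in for the gradient-flow solve: unrolled gradient descent via
# lax.scan, so the dependence of the equilibrium y* on theta is tracked.
def solve_y(theta, y0, lr=0.1, steps=200):
    def step(y, _):
        y = y - lr * jax.grad(inner_loss)(y, theta)
        return y, None
    y_star, _ = jax.lax.scan(step, y0, None, length=steps)
    return y_star

def outer_loss(theta, y0):
    y_star = solve_y(theta, y0)  # y depends on theta through the solve
    return inner_loss(y_star, theta)

theta = jnp.array([1.0, -2.0])
y0 = jnp.zeros(2)
g = jax.grad(outer_loss)(theta, y0)  # includes the dependence of y* on theta
```

This plain-`scan` sketch differentiates fine because no `custom_vjp` is involved; the point of the question is that the same pattern through `diffeqsolve`'s reverse-mode adjoint raises the `TypeError` reported in this thread when combined with forward mode.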
But using this approach I get a `TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function`.

Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?
Thanks!