patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

How to evaluate derivatives of diffrax-solved, equilibrated functions? #432

Open francesco-innocenti opened 1 month ago

francesco-innocenti commented 1 month ago

Hi!

This is a follow-up on #181. The use case is to evaluate the derivatives (e.g. gradient, Hessian) of some loss function $\mathcal{L}$ with respect to some variable $\theta$, at the equilibrium of the gradient flow of that loss with respect to some other variable $y$, i.e. where $\partial \mathcal{L}/\partial y \approx 0$. Mathematically this would be something like

$\frac{\partial \mathcal{L}(y; \theta)}{\partial \theta}\bigg|_{\frac{\partial \mathcal{L}}{\partial y} \approx 0}$

In code, building on your snippet from #181:

import diffrax
import jax


def L(y, theta):  # some loss
    ...

def dLdy(t, y, args):  # vector field for the gradient flow of L in y
    theta = args
    return -jax.grad(L)(y, theta)

def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        y0=y0,
        args=theta,
        ...
    )
    return sol.ys

def dLdtheta(y, theta):
    return jax.grad(L, argnums=1)(y, theta)

Given these, I could just solve for y and then take the gradient w.r.t. theta, like so:

y_sol = solve_y(y0, theta)
theta_grad = dLdtheta(y_sol, theta)

However, this ignores the dependence of the solved y on theta that arises during the integration of the gradient system. So ideally I would like to take the gradient of a loss that solves for y internally:

def equilibrated_L(y0, theta):  # equilibrated loss
    y_sol = solve_y(y0, theta)
    ...
    return L(y_sol, theta)

def dLdtheta(y0, theta):
    return jax.grad(equilibrated_L, argnums=1)(y0, theta)

theta_grad = dLdtheta(y0, theta)

But using this approach I get `TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.`

Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?

Thanks!

patrick-kidger commented 1 month ago

What you're doing looks reasonable to me.

The error you're getting is coming from using either jax.jvp or jax.jacfwd.

Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff. By default, Diffrax provides custom reverse-mode autodiff for diffeqsolve. You might like to try the solution of https://github.com/patrick-kidger/optimistix/pull/51#issuecomment-2105948574, which provides an alternative adjoint that supports only forward-mode autodiff instead.
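
Alternatively, here is a minimal sketch using diffrax.DirectAdjoint, which the docs describe as supporting both forward- and reverse-mode autodiff (at a substantial cost in time and memory efficiency). The solver, horizon, and step settings below are placeholder assumptions, and dLdy/L are the definitions from your snippet:

import diffrax
import jax

def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        diffrax.Tsit5(),  # placeholder solver choice
        t0=0.0,
        t1=100.0,  # placeholder horizon: integrate long enough to (approximately) equilibrate
        dt0=0.01,
        y0=y0,
        args=theta,
        adjoint=diffrax.DirectAdjoint(),  # supports forward-mode autodiff
        max_steps=2**14,
    )
    return sol.ys[-1]  # final (approximately equilibrated) state

def equilibrated_L(y0, theta):
    return L(solve_y(y0, theta), theta)

theta_grad = jax.grad(equilibrated_L, argnums=1)(y0, theta)
theta_hess = jax.hessian(equilibrated_L, argnums=1)(y0, theta)  # jacfwd over jacrev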

lockwo commented 1 month ago

> Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff.

Why is this? This doesn't seem fundamentally impossible.

patrick-kidger commented 1 month ago

jax.custom_jvp and jax.custom_vjp cannot both be applied: one of them has to be on the outside, and this is the one that autodiff sees.

The only way to do this is to define a custom primitive, but that's not exactly easy! In particular for our use case here, our custom autodiff is a loop over steps (eqx.internal.while_loop(..., kind="checkpointed")).
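
To make the restriction concrete, here is a minimal JAX-only sketch (nothing Diffrax-specific), where f is a hypothetical function with only a custom VJP defined:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    return jnp.sin(x), x  # primal output, plus residuals for the backward pass

def f_bwd(x, g):
    return (g * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(1.0)  # fine: reverse mode uses the custom VJP
jax.jacfwd(f)(1.0)  # TypeError: can't apply forward-mode autodiff (jvp)
                    # to a custom_vjp function.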

johannahaffner commented 1 month ago

jax.hessian is implemented as jacfwd(jacrev(function)), which is why it hits the forward-mode error here.

jacrev(jacrev(function)) or jacfwd(jacfwd(function)) both work, so you can adapt to the adjoint of your choice.
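
For instance, with Diffrax's default reverse-mode adjoint, a sketch (reusing equilibrated_L, y0 and theta from the snippets above):

import jax

# Hessian of the equilibrated loss w.r.t. theta via two reverse-mode
# Jacobians, avoiding the jacfwd step that jax.hessian would take.
hess_fn = jax.jacrev(jax.jacrev(equilibrated_L, argnums=1), argnums=1)
theta_hess = hess_fn(y0, theta)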