patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Taking more than one gradient fails with default RecursiveCheckpointAdjoint #332

Open nwlambert opened 10 months ago

nwlambert commented 10 months ago

I am a total beginner with JAX and diffrax, and I'm not sure if this is a bug or expected behaviour, but if I try to compute the second or higher derivative of a solution from diffeqsolve() I get an error. Changing the adjoint to DirectAdjoint() seems to fix the problem.

Minimal working example (using the default ODE example from the diffrax introduction):

```python
import jax
import jax.numpy as jnp
import numpy as np
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint

z = 2.3
t = 1.

def rhot(z):
    # dy/dt = -z * y, with z closed over from rhot's argument
    def f(t, y, args):
        return -z*y

    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
    solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0)
    # solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.1, y0=y0, adjoint=DirectAdjoint())  # changing the adjoint fixes it
    return solution.ys[0][0]  # first component of y at t1

# First derivative via reverse mode; second derivative via reverse-over-reverse
drhozdz = jax.grad(rhot, argnums=0)
d2rhozdz = jax.grad(drhozdz, argnums=0)

print("expected state ", np.exp(-z*t)*2.)
print("found state ", rhot(z))

print("expected ", -2.*t*np.exp(-z*t))
print("found 1st derivative ", drhozdz(z))

print("expected 2nd ", 2.*t**2*np.exp(-z*t))
print("found 2nd derivative ", d2rhozdz(z))  # fails with default adjoint
```

The error returned is:

```
print("found 2nd deriative ", d2rhozdz(z)) #fails with default adjoint
                              ^^^^^^^^^^^
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
```

patrick-kidger commented 10 months ago

Yup, I'm afraid this is expected. RecursiveCheckpointAdjoint does some smart things under the hood to be very efficient when computing first-order gradients specifically, but unfortunately this also makes it incompatible with certain kinds of higher-order autodiff, such as reverse-over-reverse (applying jax.grad twice, as in your example), which is what triggers the while_loop error above.
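The error itself is raised by JAX rather than by diffrax. As a minimal illustration using a hypothetical toy function `scaled_power` (not diffrax code), forward mode (`jax.jvp`) can differentiate through a `lax.while_loop`, while reverse mode (`jax.grad`) raises the same `ValueError`. Presumably reverse-over-reverse through RecursiveCheckpointAdjoint hits this because the backward pass it defines contains such a loop, as the error message suggests; that internal detail is an assumption here.

```python
import jax
import jax.numpy as jnp


def scaled_power(z, n=3):
    # Hypothetical toy function: computes (1 - 0.1*z)**n by looping n times
    # with lax.while_loop, the JAX primitive for loops of unknown length.
    def cond(carry):
        i, _ = carry
        return i < n

    def body(carry):
        i, y = carry
        return i + 1, y * (1.0 - 0.1 * z)

    _, y = jax.lax.while_loop(cond, body, (0, jnp.float32(1.0)))
    return y


z = jnp.float32(0.5)

# Forward mode through the while_loop works.
value, tangent = jax.jvp(scaled_power, (z,), (jnp.float32(1.0),))
print(value, tangent)  # roughly 0.857375 and -0.27075

# Reverse mode through the same loop raises ValueError.
try:
    jax.grad(scaled_power)(z)
    reverse_failed = False
except ValueError as err:
    reverse_failed = True
    print(err)
```

This is why the order of differentiation matters: the outer transformation in forward-over-reverse only needs forward mode through the loops that the inner reverse pass creates.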

First of all, when looking to compute the Hessian, it is usually more efficient to use forward-over-reverse (forward mode applied on top of a reverse-mode gradient) rather than reverse-over-reverse, and indeed this is what jax.hessian does. RecursiveCheckpointAdjoint should actually be compatible with that in most cases.

But nonetheless, in the general case, using DirectAdjoint is indeed the appropriate fix. (Handling edge cases like this is the reason it exists.)

You might also like the example on second-order sensitivities from the documentation.


I'm going to tag this under "refactor" as this could probably do with a more informative error message.

nwlambert commented 10 months ago

Thanks for the quick reply. I missed that documentation; it was very helpful.

Playing around a bit with a more complex example I am struggling with, I see what you mean... doing forward-over-reverse with RecursiveCheckpointAdjoint() works and seems both faster and more memory efficient than using DirectAdjoint(), so that was extremely useful! Thanks!