patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Second-Order Neural ODEs #114

Open adam-hartshorne opened 2 years ago

adam-hartshorne commented 2 years ago

I recently came across this paper from last year's NeurIPS:

https://ghliu.github.io/assets/pdf/neurips-snopt-slides.pdf

https://github.com/ghliu/snopt

I was wondering if there are any plans to extend Diffrax to support this improvement to the standard Neural ODE approach.

patrick-kidger commented 2 years ago

At first glance it looks like they're just computing Hessians by composing forward-mode autodiff with reverse-mode autodiff. This is already a standard technique, and it can be done today simply by computing jax.hessian of a diffrax.diffeqsolve.
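To make "composing forward-mode with reverse-mode" concrete, here is a minimal pure-JAX sketch. The toy function `loss` is a hypothetical stand-in for a scalar loss computed from a diffeqsolve; the point is only that jax.hessian is forward-over-reverse, so it matches jax.jacfwd(jax.jacrev(...)).

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar loss standing in for "a loss computed via diffeqsolve".
def loss(x):
    return jnp.sum(jnp.sin(x) * x**2)

x = jnp.array([0.5, -1.0, 2.0])

# jax.hessian is forward-mode (jacfwd) applied over reverse-mode (jacrev).
h_builtin = jax.hessian(loss)(x)
h_composed = jax.jacfwd(jax.jacrev(loss))(x)
```

Because `loss` is a sum of per-element terms, its Hessian is diagonal here.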

FWIW I'm not convinced by this paper:

  1. The only comparison they make to standard autodiff approaches is against a composed reverse-over-reverse, which is an exceptionally inefficient way to compute Hessians.
  2. It seems to muddle together (a) the computation of Hessians with (b) second-order optimisation techniques.
  3. It's framed in terms of control-theoretic language, which is known to be a needlessly complicated approach. These days we have much better ways of expressing autodiff through diffeqs.
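For point 1, a small pure-JAX sketch (hypothetical toy function `loss`) showing that reverse-over-reverse and forward-over-reverse produce the same Hessian. They differ only in cost: reverse-over-reverse pays reverse-mode's bookkeeping twice, which is why forward-over-reverse is generally the preferred composition.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss; the x[0] * x[1] term gives a known off-diagonal entry of 1.
def loss(x):
    return jnp.sum(jnp.tanh(x) ** 2) + x[0] * x[1]

x = jnp.array([0.3, -0.7])

# Forward-over-reverse: what jax.hessian does.
h_fwd_rev = jax.jacfwd(jax.jacrev(loss))(x)

# Reverse-over-reverse: same numbers, typically more expensive to compute.
h_rev_rev = jax.jacrev(jax.jacrev(loss))(x)
```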

(If I've got the wrong end of the stick somewhere then I'd be happy to be corrected.)

adam-hartshorne commented 2 years ago

I defer to your much greater knowledge in this area (in particular on specific criticisms of any claims they make), but my understanding from briefly scanning the paper was that the whole point of what they are doing is not simply to take the Hessian, but rather to form a much more efficient approximation than that naive approach.
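For context only: one standard, generic way to use second-order information without materialising the full Hessian is a Hessian-vector product (forward-mode applied to the gradient). This is not claimed to be what the paper does; `loss` and `hvp` are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss whose Hessian is diag(3 * x**2).
def loss(x):
    return jnp.sum(x**4) / 4.0

def hvp(f, x, v):
    # jvp of the gradient in direction v gives H @ v without building H.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
hv = hvp(loss, x, v)
```

Here `hv` should equal `3 * x**2 * v`, i.e. `[3, 0, -27]`.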

Here is the paper for reference.

https://arxiv.org/pdf/2109.14158.pdf

And the reviewers' back-and-forth:

https://openreview.net/forum?id=XwetFe0U63c

joglekara commented 1 year ago

I don't think this is fully related, but since you mentioned Hessians here: I'm running into TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function. for a simple Van der Pol solve using Tsit5() and the default adjoint method.

Here's the code

import diffrax
from diffrax import ODETerm, SaveAt, Tsit5

def vector_field(t, y, args):
    x1, x2 = y
    x1dot = x2
    x2dot = args["mu"]*(1-x1**2)*x2 - x1

    dstatedt = (x1dot, x2dot)

    return dstatedt

# tsave (the output times) is defined elsewhere in the original code and not shown here.
def solve_van_der_pol(mu):
    return diffrax.diffeqsolve(ODETerm(vector_field), solver=Tsit5(), t0=0., t1=100., dt0=0.1, y0=(1., 0.), args={"mu": mu}, saveat=SaveAt(ts=tsave))

What am I missing? RecursiveCheckpointAdjoint() shouldn't use a custom_vjp, right? Or does it use one for the checkpointing?

patrick-kidger commented 1 year ago

RecursiveCheckpointAdjoint does now use a custom_vjp under the hood.
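Why that matters: jax.hessian involves forward-mode differentiation, and a function defined with jax.custom_vjp supplies only a reverse-mode rule. A minimal pure-JAX reproduction of the same error message, using a hypothetical `square` function:

```python
import jax

@jax.custom_vjp
def square(x):
    return x**2

def square_fwd(x):
    return x**2, x            # primal output, plus residual saved for backward

def square_bwd(x, g):
    return (2.0 * x * g,)     # cotangent with respect to x

square.defvjp(square_fwd, square_bwd)

grad_value = jax.grad(square)(3.0)   # reverse mode uses square_bwd: works

try:
    jax.jvp(square, (3.0,), (1.0,))  # forward mode has no rule to use
    jvp_error = None
except TypeError as e:
    jvp_error = e
```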

I'd suggest trying DirectAdjoint instead. The only thing to bear in mind is that this adjoint method gets more expensive as max_steps increases, so make sure to set that value as small as possible.

joglekara commented 1 year ago

Ah okay, I think I forgot/didn't know about DirectAdjoint. This works, TYVM!

bytbox commented 1 year ago

For anyone who (like me) finds the above discussion about how to compute a Hessian of a function involving diffeqsolve a bit opaque, the precise suggestion made is to add an argument adjoint=diffrax.DirectAdjoint() to the diffeqsolve(...) call. Additionally, this is incompatible with max_steps=None, since memory proportional to the maximum number of steps will be allocated at the beginning of the derivative computation.

As I understand it (please correct me if this is wrong), in cases where you need the Hessian only occasionally but need the gradient frequently, it will be more efficient to define two different functions: one using DirectAdjoint for use with jax.hessian, and the other using the default RecursiveCheckpointAdjoint for computing first-order derivatives.

Full example:

from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint
import jax
import jax.numpy as jnp

def f(t, y, args):
    return y**2

term = ODETerm(f)
solver = Dopri5()

def g(x):
    solution = diffeqsolve(term, solver, t0=0, t1=0.1, dt0=0.1, y0=x)
    return jnp.sum(solution.ys[-1])

def g_for_hessian(x):
    solution = diffeqsolve(term, solver, t0=0, t1=0.1, dt0=0.1, y0=x, adjoint=DirectAdjoint())
    return jnp.sum(solution.ys[-1])

print(jax.grad(g)(jnp.array([1., 2.])))
# print(jax.hessian(g)(jnp.array([1., 2.])))            # Gives an error
# print(jax.grad(g_for_hessian)(jnp.array([1., 2.])))   # Inefficient (?)
print(jax.hessian(g_for_hessian)(jnp.array([1., 2.])))
joglekara commented 1 year ago

That makes sense to me!

patrick-kidger commented 1 year ago

Actually, I recently landed some improvements on this front! RecursiveCheckpointAdjoint is now compatible with jax.hessian. In particular you should be able to define a single function for both gradients and Hessians.

I've not extensively benchmarked jax.hessian-of-RecursiveCheckpointAdjoint against jax.hessian-of-DirectAdjoint. Let me know how it goes.

Moreover there is now a dedicated example describing how to compute Hessians. (And for the sake of any future readers: please do consult that for the easiest guide!)

At the time of writing there is still one gotcha: you need to replace e.g. Dopri5() with Dopri5(scan_kind="bounded"). This is a flag that DirectAdjoint will set for you automatically, but RecursiveCheckpointAdjoint doesn't, for minor performance reasons. I think I know how to make manually setting this unnecessary, but it's pretty technical so I haven't gotten around to it yet!