adam-hartshorne opened this issue 2 years ago

I recently came across this paper from last year's NeurIPS:
https://ghliu.github.io/assets/pdf/neurips-snopt-slides.pdf
https://github.com/ghliu/snopt
I was wondering if there are any plans to extend Diffrax to support this improvement to the standard Neural ODE approach.

FWIW I'm not convinced by this paper: at first glance it looks like they're just computing Hessians by composing forward-mode autodiff with reverse-mode autodiff. This is already a standard technique, and it can be done today simply by computing jax.hessian of a diffrax.diffeqsolve. (If I've got the wrong end of the stick somewhere then I'd be happy to be corrected.)
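For readers unfamiliar with the terminology: here is a tiny sketch (not from the thread) of the standard technique being referred to. jax.hessian is just forward-over-reverse autodiff, i.e. jacfwd of grad; the loss below is a stand-in scalar function rather than an actual diffeqsolve.

import jax
import jax.numpy as jnp

def loss(y0):
    # Stand-in for a scalar function of a diffrax.diffeqsolve output.
    return jnp.sum(jnp.tanh(y0) ** 2)

y0 = jnp.array([1.0, 2.0])
hess_a = jax.hessian(loss)(y0)           # convenience wrapper
hess_b = jax.jacfwd(jax.grad(loss))(y0)  # forward-mode over reverse-mode, spelled out
assert jnp.allclose(hess_a, hess_b)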
I defer to your much greater knowledge in this area (in particular specific criticisms of any claims they make), but my understanding from briefly scanning the paper was that the whole point of what they are doing is not simply taking the Hessian, but rather forming a much more efficient approximation than such a naive approach.
Here is the paper for reference:
https://arxiv.org/pdf/2109.14158.pdf
And the reviewers' back and forth.
I don't think this is fully related, but since you mentioned hessian here: for a simple Van der Pol solve using Tsit5() and the default adjoint method, I'm running into TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
Here's the code:
import diffrax
import jax.numpy as jnp
from diffrax import ODETerm, SaveAt, Tsit5

tsave = jnp.linspace(0., 100., 1000)  # save times (assumed; not shown in the original snippet)

def vector_field(t, y, args):
    x1, x2 = y
    x1dot = x2
    x2dot = args["mu"] * (1 - x1**2) * x2 - x1
    dstatedt = (x1dot, x2dot)
    return dstatedt

def solve_van_der_pol(mu):
    return diffrax.diffeqsolve(ODETerm(vector_field), solver=Tsit5(), t0=0., t1=100., dt0=0.1, y0=(1., 0.), args={"mu": mu}, saveat=SaveAt(ts=tsave))
What am I missing? RecursiveCheckpointAdjoint() shouldn't use a custom_vjp, right? Or does it do so re: checkpointing?
RecursiveCheckpointAdjoint does now use a custom_vjp under the hood.
I'd suggest trying DirectAdjoint instead. The only thing to bear in mind is that this adjoint method gets more expensive as max_steps increases, so make sure to set that value as small as possible.
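For concreteness, here is a minimal sketch (not from the thread) of that suggestion applied to the Van der Pol snippet above, reusing its definitions. The function name, the choice of scalar output, and the max_steps value are illustrative.

import jax
from diffrax import DirectAdjoint

def solve_with_direct_adjoint(mu):
    # Same solve as above, but with DirectAdjoint and an explicit, modest max_steps
    # (comfortably above the ~1000 constant-size steps this solve takes).
    sol = diffrax.diffeqsolve(
        ODETerm(vector_field), solver=Tsit5(), t0=0., t1=100., dt0=0.1,
        y0=(1., 0.), args={"mu": mu}, saveat=SaveAt(ts=tsave),
        adjoint=DirectAdjoint(), max_steps=2000,
    )
    return sol.ys[0][-1]  # scalar output: x1 at the final save time

print(jax.hessian(solve_with_direct_adjoint)(1.0))  # second derivative w.r.t. mu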
Ah okay, I think I forgot/didn't know about DirectAdjoint. This works, TYVM!
For anyone who (like me) finds the above discussion about how to compute a Hessian of a function involving diffeqsolve a bit opaque, the precise suggestion made is to add an argument adjoint=diffrax.DirectAdjoint() to the diffeqsolve(...) call. Additionally, this is incompatible with max_steps=None, since memory proportional to the maximum number of steps will be allocated at the beginning of the derivative computation.
As I understand it (please correct me if this is wrong), in cases where you need the Hessian only occasionally but need the gradient frequently, it will be more efficient to define two different functions: one using DirectAdjoint for use with jax.hessian, and the other using the default RecursiveCheckpointAdjoint for computing first-order derivatives.
Full example:
from diffrax import diffeqsolve, ODETerm, Dopri5, DirectAdjoint
import jax
import jax.numpy as jnp

def f(t, y, args):
    return y**2

term = ODETerm(f)
solver = Dopri5()

def g(x):
    solution = diffeqsolve(term, solver, t0=0, t1=0.1, dt0=0.1, y0=x)
    return jnp.sum(solution.ys[-1])

def g_for_hessian(x):
    solution = diffeqsolve(term, solver, t0=0, t1=0.1, dt0=0.1, y0=x, adjoint=DirectAdjoint())
    return jnp.sum(solution.ys[-1])

print(jax.grad(g)(jnp.array([1., 2.])))
# print(jax.hessian(g)(jnp.array([1., 2.])))          # Gives an error
# print(jax.grad(g_for_hessian)(jnp.array([1., 2.])))  # Inefficient (?)
print(jax.hessian(g_for_hessian)(jnp.array([1., 2.])))
That makes sense to me!
Actually, I recently landed some improvements on this front! RecursiveCheckpointAdjoint is now compatible with jax.hessian. In particular you should be able to define a single function for both gradients and Hessians.
I've not extensively benchmarked jax.hessian-of-RecursiveCheckpointAdjoint against jax.hessian-of-DirectAdjoint. Let me know how it goes.
Moreover there is now a dedicated example describing how to compute Hessians. (And for the sake of any future readers: please do consult that for the easiest guide!)
At time of writing there is still one gotcha: you need to replace e.g. Dopri5() with Dopri5(scan_kind="bounded"). This is a flag that DirectAdjoint will set for you automatically, but RecursiveCheckpointAdjoint doesn't, for minor performance reasons. I think I know how to make manually setting this unnecessary, but it's pretty technical so I haven't gotten around to it yet!
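To make that concrete, here is a minimal sketch (not from the thread, and assuming a Diffrax version that includes the improvements described above). It mirrors the earlier full example, but with a single function serving both first- and second-order derivatives; the name g_single is just illustrative.

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax
import jax.numpy as jnp

def f(t, y, args):
    return y**2

term = ODETerm(f)
solver = Dopri5(scan_kind="bounded")  # the one extra flag currently needed for Hessians

def g_single(x):
    # Default adjoint (RecursiveCheckpointAdjoint); no DirectAdjoint needed any more.
    solution = diffeqsolve(term, solver, t0=0, t1=0.1, dt0=0.1, y0=x)
    return jnp.sum(solution.ys[-1])

print(jax.grad(g_single)(jnp.array([1., 2.])))     # first-order, as before
print(jax.hessian(g_single)(jnp.array([1., 2.])))  # second-order, now also supported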