jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot compute Hessian of `custom_vjp`. #16049

Closed patrick-kidger closed 1 year ago

patrick-kidger commented 1 year ago

Description

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
jax.hessian(f)(1., 2.)
# TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

This should be doable by computing the JVP of f_bwd.

It looks like it's another case of missing DCE. For example, this works:

import jax.interpreters.partial_eval as pe
jaxpr = jax.make_jaxpr(jax.jacrev(f))(1., 2.)
dce_jaxpr, _ = pe.dce_jaxpr(jaxpr.jaxpr, [True] * len(jaxpr.out_avals), instantiate=True)
jaxpr = jax.core.ClosedJaxpr(dce_jaxpr, jaxpr.consts)
fn = jax.core.jaxpr_as_fun(jaxpr)
jax.jacfwd(fn)(1., 2.)

The only thing I'm not sure about is how to implement this in an eager-compatible way. Neither jax.jacfwd nor jax.jvp actually stages things out into a jaxpr, I don't think.

What jax/jaxlib version are you using?

0.4.10

mattjj commented 1 year ago

This should be doable by computing the JVP of f_bwd.

Hm, I don't think that's right: in general we need to apply jvp to the whole forward and backward pass, i.e. both f_fwd and f_bwd, where f_fwd may include an application of f, which is a custom_vjp function. It happens to work in this case only because two special cases coincide; if either were not present, the DCE wouldn't help:

First, in the example's f_fwd the primal output is not used as a residual, but in many cases it is. We can simulate such a case here by adapting f_fwd:

def f_fwd(x, y):
  z = f(x, y)
  return z, (z, jnp.sin(x), y)

Running that gives us an error even with the DCE included.

The second special case is that we aren't using the primal output of f downstream, i.e. we aren't differentiating a composition involving f but rather just f by itself. If we adapt the original example (i.e. don't use the f_fwd modification just above) by changing the make_jaxpr line to something like

jaxpr = jax.make_jaxpr(jax.jacrev(lambda x, y: jnp.sin(f(x, y))))(1., 2.)

then again the DCE step doesn't rescue us.

I think the issue here is just that jax.hessian is defined using forward-over-reverse, which requires jvp, which rules out custom_vjp. People hitting this issue should just use an alternative way to compute the Hessian, like reverse-over-reverse: hessian = lambda f: jax.jacrev(jax.jacrev(f)).
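
Concretely, with the f from the original post, the reverse-over-reverse version should go through, since it only ever needs VJPs (quick sketch):

hessian_rr = lambda fn: jax.jacrev(jax.jacrev(fn))
hessian_rr(f)(1., 2.)  # works: d^2/dx^2 of sin(x) * y, i.e. -jnp.sin(1.) * 2.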

What do you think?

mattjj commented 1 year ago

We could make JVP-of-custom_vjp work via automatic transposition, but that also runs into the eager problem (i.e. we can only transpose functions for which we can form a jaxpr, which isn't implied by the usual custom_vjp assumptions).

patrick-kidger commented 1 year ago

Ah, bother -- agreed.

Neither double-jacrev nor jvp-of-custom_vjp-via-transposition would work for me, as my VJP uses a lax.while_loop. And the former would be really inefficient too.

( Context: my use case is the reverse-mode-capable while_loop I wrote a while back. As a custom_vjp it supports first-order reverse-mode autodiff, but not forward-mode or higher-order. The rest of JAX lets you compose transformations arbitrarily, and I'd really like to bring while_loop up to that bar! Probably the only approach will be a custom higher-order primitive, with checkpointing happening in the partial evaluation rule, but that's not easy. )

mattjj commented 1 year ago

What you said makes sense! We'll crack loops together... someday.

Should we close this issue?

patrick-kidger commented 1 year ago

Actually, I have had one realisation. Right now the custom_vjp fwd pass for eqxi.while_loop deliberately doesn't call back into itself. This means that the particular case of Hessians-of-Diffrax can be handled via DCE!

That suggests perhaps wanting a jax.dce public API wrapper for pe.dce_jaxpr. Either way, I'll close this issue for now, as that's really a different ask, and one which I suspect is probably more jex than JAX.
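
For reference, a rough sketch of what such a wrapper might look like, just repackaging the pe.dce_jaxpr recipe from the top of this thread (the name dce here is hypothetical; nothing like this exists in JAX today):

import jax
import jax.interpreters.partial_eval as pe

def dce(fn):
    # hypothetical wrapper: trace fn, dead-code-eliminate the jaxpr, and
    # evaluate the pruned jaxpr on the same arguments
    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)
        pruned, _ = pe.dce_jaxpr(closed.jaxpr, [True] * len(closed.out_avals),
                                 instantiate=True)
        return jax.core.jaxpr_as_fun(jax.core.ClosedJaxpr(pruned, closed.consts))(*args)
    return wrapped

jax.jacfwd(dce(jax.jacrev(f)))(1., 2.)  # the same workaround as above, spelt as a transformation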

(I'm guessing pe.dce_jaxpr is stable enough that I can have Diffrax depend on it directly in the meantime.)

mlysy commented 11 months ago

I would have set up the custom_vjp() above as follows:

def _f(x, y):
    "Returns any intermediate values used in  f_bwd."
    sin_x = jnp.sin(x)
    return sin_x * y, sin_x

@jax.custom_vjp
def f(x, y):
    return _f(x, y)[0]

def f_fwd(x, y):
    ans, sin_x = _f(x, y)
    return ans, (jnp.cos(x), sin_x, y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
jax.hessian(f)(1., 2.) # works and gives same value as jax.hessian(lambda x, y: jnp.sin(x) * y)(1., 2.)

Apologies if this was clear from the discussion above, but is there something wrong with this general approach (i.e., returning intermediate computations from f that are needed for f_bwd so as to avoid duplicating them in f_fwd)?

For example, is jax.hessian() doing a (typically unwanted) JVP through f_fwd with this? Is the above approach OK as long as I get the hessian via reverse-over-reverse?

patrick-kidger commented 11 months ago

It's typical for f_fwd to call back into the custom_vjp-wrapped f.

The reason: suppose you've already done a single reverse-mode AD call, so that f_fwd and f_bwd have been traced through. If you then wanted to do a second (higher-order) reverse-mode AD call -- through f_fwd and f_bwd -- you'd typically want it to hit the custom_vjp-wrapped version of f, not the underlying _f.

For example, is jax.hessian() doing a (typically unwanted) JVP through f_fwd with this?

Yes it's doing a JVP; no it's not unwanted.

Is the above approach OK as long as I get the hessian via reverse-over-reverse?

No, because then you hit the scenario I just laid out: you end up doing reverse-of-_f, not reverse-of-f.
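
To make the difference visible, here's a deliberately contrived sketch (f is redefined as a one-argument toy, and the bwd rule is intentionally not the true derivative, purely so the two choices of f_fwd give different numbers under reverse-over-reverse):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    z = f(x)          # calls back into the custom_vjp-wrapped f
    return z, z       # primal output reused as the residual

def f_bwd(z, g):
    return (z * g,)   # deliberately fake "derivative" = f(x), to show which path is taken

f.defvjp(f_fwd, f_bwd)
jax.grad(jax.grad(f))(1.)  # ~0.841 = sin(1.): the second reverse pass hits the custom rule again

# same thing, but with f_fwd computing jnp.sin directly rather than calling f:
def f_fwd2(x):
    z = jnp.sin(x)
    return z, z

f.defvjp(f_fwd2, f_bwd)
jax.grad(jax.grad(f))(1.)  # ~0.540 = cos(1.): the second reverse pass differentiates the raw
                           # jnp.sin and never sees the custom rule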

mlysy commented 11 months ago

Thanks very much for the additional explanation.

In my use-case I have something like

def f(x):
    y = _dont_autodiff(x)   # involves a loop
    z = _do_autodiff(x, y)  # no loops; autodiff OK
    return z

The inner functions will never get called outside of f(), and the custom VJP rule f_bwd() is a simple (autodiffable) function of y and z.

Since _dont_autodiff() is never used for anything except the internals of f(), I didn't bother working out a custom VJP for it. Can I avoid the pitfalls above by defining a custom VJP rule for _f(), but ignoring the internal computations as outputs in the rule? Here's what I mean:

@jax.custom_vjp
def _f(x, y):
    "Returns desired output and intermediate calculations required for reverse-mode rule for f()."
    sin_x = jnp.sin(x)
    ans = sin_x * y
    return ans, sin_x

def _f_fwd(x, y):
    out, sin_x = _f(x, y)
    return (out, sin_x), (sin_x, jnp.cos(x), y)

def _f_bwd(res, out_bar):
    sin_x, cos_x, y = res
    ans_bar, sin_x_bar = out_bar # ignore sin_x_bar completely
    return (cos_x * y * ans_bar, sin_x * ans_bar)

_f.defvjp(_f_fwd, _f_bwd)

def f(x, y):
    return _f(x, y)[0]

# jax.hessian(f)(1., 2.)      # fails
jax.grad(jax.grad(f))(1., 2.) # correct answer

# check composition of functions
def g(x, y):
    return jnp.exp(f(x, y))

jax.grad(jax.grad(g))(1., 2.) # also gives correct answer 

This seems to work, but I'm still a bit shaken by my previous don't-duplicate-calculations attempt...

patrick-kidger commented 11 months ago

If I'm understanding you correctly, then what you're doing is fine. I think the only possible issue is that double-reverse is usually less efficient than forward-over-reverse, but it sounds like you know what you're doing / that this really is what you want here.

mlysy commented 11 months ago

Again, thanks very much for your feedback on this.

Ideally I would like to do forward-over-reverse for my own use-case so as to maximize efficiency. Unfortunately I don't understand the JAX internals well enough to adapt your DCE example above anytime soon :(

patrick-kidger commented 11 months ago

If you want to support forward-over-reverse (and don't need to worry about reverse-over-reverse) then simply make sure f_fwd (and f_bwd) is JVP'able. In particular, it mustn't call anything with a custom_vjp -- such as the original wrapped f.

Removing that inner custom_vjp will of course mean you mustn't do reverse-over-reverse. Right now, you have to pick whether you want to support forward-over-reverse or reverse-over-reverse.

Finally, if your original wrapped f isn't even JVP'able then you may need to use a custom_jvp inside your f_fwd.
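
For completeness, here's a rough sketch (not from this thread, everything is illustrative) of that last pattern: the non-JVP'able core gets a custom_jvp, f_fwd calls that core rather than the custom_vjp-wrapped f, and forward-over-reverse then goes through.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def _core(x, y):
    # stand-in for something not directly JVP'able (e.g. an iterative solve)
    return jnp.sin(x) * y

@_core.defjvp
def _core_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    return _core(x, y), jnp.cos(x) * y * dx + jnp.sin(x) * dy

@jax.custom_vjp
def f(x, y):
    return _core(x, y)

def f_fwd(x, y):
    return _core(x, y), (x, y)   # calls the JVP'able core, not the custom_vjp-wrapped f

def f_bwd(res, g):
    x, y = res
    return (jnp.cos(x) * y * g, jnp.sin(x) * g)

f.defvjp(f_fwd, f_bwd)
jax.hessian(f)(1., 2.)  # forward-over-reverse: the outer jvp only ever meets _core's custom rule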

mlysy commented 11 months ago

My f() involves an iterative solver, so as per your explanation above I would need to be mindful not to JVP through it directly.

That being said, upon carefully rereading the part on implicit differentiation of iterative solvers in the JAX custom derivatives tutorial, I'm realizing that VJPs aren't the only way to achieve this: JVPs can also be used, though they are perhaps less computationally efficient depending on the problem. In fact the jax.lax.custom_root() function referenced at the end of that section seems to use custom JVPs.

I guess I'll have to see which is faster for my use-case, maybe bearing in mind the benefit of supporting whatever-over-whatever derivatives like the rest of JAX...

patrick-kidger commented 11 months ago

Aha, in that case -- implicit differentiation with a custom_jvp is indeed the way to go. JAX will actually be able to synthesise the vjp from this automatically (and this will be efficient too).

jax.lax.custom_root is one option. Another implementation of this (that I prefer) is also available at optimistix.internal.implicit_jvp.

As this is implemented in terms of a linear solve, it should be compatible with arbitrarily high-order AD.
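
For reference, the overall pattern looks roughly like the following sketch. The solver and residual function here are toy stand-ins (this is not the jax.lax.custom_root or Optimistix implementation), but the implicit-function-theorem JVP rule is the key idea:

import jax
import jax.numpy as jnp

def residual(x, theta):
    # we seek the root in x of residual(x, theta) = 0; here x* = theta ** (1/3)
    return x ** 3 - theta

@jax.custom_jvp
def solve(theta):
    # plain Newton iteration; we never differentiate through this loop directly
    x = jnp.ones_like(theta)
    for _ in range(20):
        x = x - residual(x, theta) / jax.grad(residual, argnums=0)(x, theta)
    return x

@solve.defjvp
def solve_jvp(primals, tangents):
    (theta,), (dtheta,) = primals, tangents
    x = solve(theta)
    # implicit function theorem: dx = -(dF/dx)^{-1} (dF/dtheta) dtheta
    dF_dx = jax.grad(residual, argnums=0)(x, theta)
    dF_dtheta = jax.grad(residual, argnums=1)(x, theta)
    return x, -(dF_dtheta / dF_dx) * dtheta

jax.grad(solve)(8.0)     # ~0.0833 = 1/12: the vjp is synthesised from the jvp rule
jax.hessian(solve)(8.0)  # higher-order works too, since the rule is itself differentiable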

(Also, if your iterative method is a nonlinear solver, then you might like to implement it as a custom solver in Optimistix directly.)

mlysy commented 11 months ago

Wonderful! Thanks for pointing me to Optimistix -- I actually just need a Newton root finder so it seems I'm all set :)