jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implementing `stop_hessian` in reverse mode (how to transpose `stop_gradient`?) #10994

Open Gattocrucco opened 2 years ago

Gattocrucco commented 2 years ago

I am using a trick to compute the Fisher information matrix from a log likelihood with JAX: I take the Hessian with the second-order derivatives partially disabled. Specifically, I apply stop_gradient to the tangents in the custom JVPs of my linear algebra operations. However, this works only with forward-mode derivatives; in reverse mode JAX complains about the missing transpose rule for stop_gradient.

Here is a minimal example with forward derivatives, which works fine:

import jax
from jax import numpy as jnp

@jax.custom_jvp
def f(x):
    return 1/2 * x ** 2

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = jax.lax.stop_gradient(tangents)
    return f(x), x * x_dot

def g(x):
    return 1/2 * x ** 2

def h(x):
    return f(g(x))

# h = 1/2 g^2
# h' = g g'
# h'' = g' g' + g g''
#               ^^^^^ stop_gradient removes this

h2 = jax.vmap(jax.jacfwd(jax.jacfwd(h)))
x = jnp.arange(-8., 9.)
assert jnp.allclose(h2(x), x ** 2)

Changing jax.jacfwd(jax.jacfwd(h)) to any of jax.jacfwd(jax.jacrev(h)), jax.jacrev(jax.jacfwd(h)), or jax.jacrev(jax.jacrev(h)) produces the following error:

NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'stop_gradient' not implemented

Of course, I could solve my specific problem by directly writing a JVP for the Fisher information using the handwritten expression, but this trick is very convenient, and it seems to me that it is operationally well defined, so I would expect stop_gradient to work.
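
For context, here is a toy sketch of why the trick gives the Fisher information (the Gaussian model and data below are made up for illustration, they are not my actual code). For a Gaussian likelihood with fixed covariance, minus the Hessian of the log likelihood is J^T J plus terms weighted by the residuals and the second derivatives of the model; the Fisher information is just the J^T J part, which is exactly what is left once the second-order derivatives are disabled:

import jax
from jax import numpy as jnp

def model(theta, x):
    # made-up nonlinear mean function standing in for the real model
    return jnp.exp(theta[0] * x) + theta[1]

def loglike(theta, x, y):
    # Gaussian log likelihood with unit variance, up to an additive constant
    return -0.5 * jnp.sum((y - model(theta, x)) ** 2)

theta = jnp.array([0.1, 0.3])
x = jnp.linspace(0., 1., 10)
y = model(theta, x) + 0.1  # fake observations

J = jax.jacobian(model)(theta, x)          # (10, 2) Jacobian of the mean w.r.t. theta
fisher = J.T @ J                           # Fisher information = J^T J for this model
full = -jax.hessian(loglike)(theta, x, y)  # also contains residual * (d2 model) terms
print(fisher)
print(full)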

YouJiacheng commented 2 years ago
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p
import jax

@jax.custom_jvp
def f(x):
    return x

@f.defjvp
def f_jvp(x, t):
    return x[0], jax.lax.stop_gradient(t[0])

ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [ct]

print(jax.vjp(f, 1.0)[1](1.0)) # ok

Hmmm, this transpose rule cannot stop the gradient correctly... it makes stop_gradient equivalent to the identity in reverse-mode autodiff. Okay, I think the root problem is this: if we want to stop the gradient, the transpose rule should return zero. However, if stop_gradient returns zero in its transpose rule and stop_gradient is used in another function's JVP, then in that function's VJP stop_gradient will be transposed to the zero function - in your use case, f's VJP would be zero!

IIUC, it is impossible to achieve your goal in this way...
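
To make the failure mode concrete, here is a minimal sketch with a zero transpose rule (same internals as above; zeros_like is just one way to return a zero cotangent):

import jax
from jax import numpy as jnp
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

@jax.custom_jvp
def f(x):
    return x

@f.defjvp
def f_jvp(x, t):
    return x[0], jax.lax.stop_gradient(t[0])

# a transpose rule that actually stops the gradient: return a zero cotangent
ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jnp.zeros_like(ct)]

print(jax.vjp(f, 1.0)[1](1.0))  # zero cotangent: f's whole VJP collapses to zero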

Gattocrucco commented 2 years ago

If we want to stop the gradient, the transpose rule should return zero. However, if stop_gradient returns zero in its transpose rule and stop_gradient is used in another function's JVP, then in that function's VJP stop_gradient will be transposed to the zero function - in your use case, f's VJP would be zero!

To restore consistency, it is sufficient to wrap ct in stop_gradient again:

from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p
import jax

@jax.custom_jvp
def f(x):
    return x ** 2

@f.defjvp
def f_jvp(x, t):
    return x[0] ** 2, jax.lax.stop_gradient(t[0])

ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jax.lax.stop_gradient(ct)]

print(jax.grad(f)(1.)) # -> 1
print(jax.grad(jax.grad(f))(1.)) # -> 0

Thanks for showing me how to use jax's internals to set a custom transpose for stop_gradient.

YouJiacheng commented 2 years ago

@Gattocrucco Hmmm, if you use stop_gradient(ct), you will get jax.jacfwd(jax.jacrev(h))(1.0) == 0.5 instead of 1. I don't think stopping the gradient on the cotangent is correct...
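
To check, here is the example from the first message reassembled with the stop_gradient(ct) transpose rule installed (a quick sketch, same definitions as above):

import jax
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jax.lax.stop_gradient(ct)]

@jax.custom_jvp
def f(x):
    return 1/2 * x ** 2

@f.defjvp
def f_jvp(primals, tangents):
    x, = primals
    x_dot, = jax.lax.stop_gradient(tangents)
    return f(x), x * x_dot

def g(x):
    return 1/2 * x ** 2

def h(x):
    return f(g(x))

print(jax.jacfwd(jax.jacfwd(h))(1.0))  # 1.0: the g' g' term, as intended
print(jax.jacfwd(jax.jacrev(h))(1.0))  # 0.5: the g g'' term survives instead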

Gattocrucco commented 2 years ago

Ach

Gattocrucco commented 2 years ago

I still don't have a solution, but I have clarified my question: how to implement a stop_hessian function. My current half-broken attempt is:

import jax
from jax import numpy as jnp
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jax.lax.stop_gradient(ct)]

@jax.custom_jvp
def stop_hessian(x):
    return x

@stop_hessian.defjvp
def stop_hessian_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return x, jax.lax.stop_gradient(x_dot)

def f(x):
    return jnp.sin(jnp.cos(x))

def g(x):
    return 1/2 * x ** 2

def h(x):
    return g(stop_hessian(f(x)))

# h = 1/2 f^2
# h' = f f'
# h'' = f' f' + f f''
#               ^^^^^ stop_hessian should remove this

x = jnp.arange(4.)
f1 = jax.vmap(jax.grad(f))
f2 = jax.vmap(jax.grad(jax.grad(f)))
print("f' f':  ", f1(x) ** 2)
print("f f'':  ", f(x) * f2(x))
print('total:  ', f1(x) ** 2 + f(x) * f2(x))
print('fwd-fwd:', jax.vmap(jax.jacfwd(jax.jacfwd(h)))(x)) # ok -> f' f'
print('rev-fwd:', jax.vmap(jax.jacrev(jax.jacfwd(h)))(x)) # ok -> f' f'
print('fwd-rev:', jax.vmap(jax.jacfwd(jax.jacrev(h)))(x)) # wrong -> f f'' (?)
print('rev-rev:', jax.vmap(jax.jacrev(jax.jacrev(h)))(x)) # wrong -> f f'' (?)

f' f':   [0.         0.5207154  0.69171137 0.00599572]
f f'':   [-0.4546487  -0.42569685 -0.28897595 -0.46805048]
total:   [-0.4546487   0.09501857  0.4027354  -0.46205476]
fwd-fwd: [0.         0.5207154  0.69171137 0.00599572]
rev-fwd: [0.         0.5207154  0.69171137 0.00599572]
fwd-rev: [-0.45464867 -0.42569682 -0.28897592 -0.46805045]
rev-rev: [-0.45464867 -0.42569682 -0.28897592 -0.46805045]
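
If I trace by hand what reverse mode does here with this transpose rule, jax.grad(h) ends up computing f'(x) * stop_gradient(f(x)): the stop_gradient lands on the cotangent f(x) instead of on the tangent f'(x). Differentiating that once more keeps the f f'' term and drops f' f', which matches the fwd-rev and rev-rev rows above (this only explains the numbers, it does not fix them). The jaxpr of the reverse pass, run after the snippet above, shows where the stop_gradient ends up:

print(jax.make_jaxpr(jax.grad(h))(1.0))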