Gattocrucco opened this issue 2 years ago (status: Open)
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p
import jax

@jax.custom_jvp
def f(x):
    return x

@f.defjvp
def f_jvp(x, t):
    # x and t are the tuples of primals and tangents; the tangent is stopped.
    return x[0], jax.lax.stop_gradient(t[0])

# Naive fix: transpose stop_gradient as the identity.
ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [ct]

print(jax.vjp(f, 1.0)[1](1.0))  # "ok": runs without the missing-transpose error
Hmmm, this transpose rule cannot stop the gradient correctly... it makes stop_gradient equivalent to the identity in reverse-mode autodiff.
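A quick check of that claim, reusing f and the identity transpose from the snippet above (my own sketch; 0.0 is what a genuinely stopped gradient would produce):

# With the identity transpose registered above, the cotangent flows
# straight through stop_gradient, so f differentiates like the identity:
print(jax.grad(f)(1.0))  # 1.0, but a real stop_gradient would give 0.0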
Okay, I think the root problem is this: if we want to stop the gradient, the transpose rule should return zero. However, if stop_gradient returns zero in its transpose rule, and stop_gradient is used in another function's JVP, then in that function's VJP the stop_gradient will be transposed to the zero function; in your use case, f's VJP will be zero! IIUC, it is impossible to achieve your goal this way...
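For concreteness, a sketch of the zero-returning transpose described above (my own stand-in, using a concrete zeros_like cotangent rather than a symbolic zero), reusing f from the first snippet:

from jax import numpy as jnp

# If the transpose rule returned zero instead of passing ct through:
ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jnp.zeros_like(ct)]
print(jax.vjp(f, 1.0)[1](1.0))  # (0.0,): f's VJP is now identically zero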
To restore consistency, it is sufficient to wrap ct in stop_gradient again:
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p
import jax

@jax.custom_jvp
def f(x):
    return x ** 2

@f.defjvp
def f_jvp(x, t):
    # The primal output must match f; the tangent is deliberately stopped.
    return x[0] ** 2, jax.lax.stop_gradient(t[0])

# Transpose stop_gradient to stop_gradient of the cotangent.
ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jax.lax.stop_gradient(ct)]

print(jax.grad(f)(1.))            # -> 1
print(jax.grad(jax.grad(f))(1.))  # -> 0
Thanks for showing me how to use jax's internals to set a custom transpose for stop_gradient.
@Gattocrucco Hmmm, if you use stop_gradient(ct), you will get jax.jacfwd(jax.jacrev(h))(1.0) == 0.5 instead of 1. I don't think stopping the gradient on the cotangent is correct...
Ach. I still don't have a solution, but I have clarified my question: how to implement a function stop_hessian. The current half-broken attempt is:
import jax
from jax import numpy as jnp
from jax.interpreters import ad
from jax._src.ad_util import stop_gradient_p

ad.primitive_transposes[stop_gradient_p] = lambda ct, _: [jax.lax.stop_gradient(ct)]

@jax.custom_jvp
def stop_hessian(x):
    return x

@stop_hessian.defjvp
def stop_hessian_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return x, jax.lax.stop_gradient(x_dot)

def f(x):
    return jnp.sin(jnp.cos(x))

def g(x):
    return 1/2 * x ** 2

def h(x):
    return g(stop_hessian(f(x)))

# h = 1/2 f^2
# h' = f f'
# h'' = f' f' + f f''
#               ^^^^^ stop_hessian should remove this term

x = jnp.arange(4.)
f1 = jax.vmap(jax.grad(f))
f2 = jax.vmap(jax.grad(jax.grad(f)))
print("f' f': ", f1(x) ** 2)
print("f f'': ", f(x) * f2(x))
print('total: ', f1(x) ** 2 + f(x) * f2(x))
print('fwd-fwd:', jax.vmap(jax.jacfwd(jax.jacfwd(h)))(x))  # ok -> f' f'
print('rev-fwd:', jax.vmap(jax.jacrev(jax.jacfwd(h)))(x))  # ok -> f' f'
print('fwd-rev:', jax.vmap(jax.jacfwd(jax.jacrev(h)))(x))  # wrong -> f f'' (?)
print('rev-rev:', jax.vmap(jax.jacrev(jax.jacrev(h)))(x))  # wrong -> f f'' (?)
f' f': [0. 0.5207154 0.69171137 0.00599572]
f f'': [-0.4546487 -0.42569685 -0.28897595 -0.46805048]
total: [-0.4546487 0.09501857 0.4027354 -0.46205476]
fwd-fwd: [0. 0.5207154 0.69171137 0.00599572]
rev-fwd: [0. 0.5207154 0.69171137 0.00599572]
fwd-rev: [-0.45464867 -0.42569682 -0.28897592 -0.46805045]
rev-rev: [-0.45464867 -0.42569682 -0.28897592 -0.46805045]
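Tracing through it (my own reading, not verified against JAX internals): jacrev transposes the JVP, and with the stop_gradient(ct) rule the incoming cotangent g'(f) = f arrives wrapped in stop_gradient, so jax.jacrev(h) computes f' * stop_gradient(f). Its value is the correct h' = f f', but differentiating it once more kills the stopped factor, leaving f'' * f: exactly the term stop_hessian was meant to remove, while f' f' is lost. The first derivative itself agrees in both modes:

# The first derivative is numerically fine either way; only the placement of
# stop_gradient inside the residual program differs (names from the code above):
print("rev h':", jax.vmap(jax.grad(h))(x))  # f' * stop_gradient(f) in value
print("f f'  :", f(x) * f1(x))              # same numbers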
I am using a trick to compute the Fisher information matrix from a log likelihood with JAX by taking the Hessian with partially disabled second-order derivatives. In particular, I apply stop_gradient to the tangents in the custom JVPs of my linear algebra operations. However, this works only with forward derivatives; in reverse mode JAX complains about the missing transpose rule for stop_gradient. The minimal example above with forward derivatives works fine; changing jax.jacfwd(jax.jacfwd(h)) to any of jax.jacfwd(jax.jacrev(h)), jax.jacrev(jax.jacfwd(h)) or jax.jacrev(jax.jacrev(h)) produces an error about the missing transpose rule for stop_gradient. Of course I could solve my specific problem by writing a JVP for the Fisher information directly from the handwritten expression, but this trick is very convenient, and it seems to me that it is operationally well defined, so I would expect stop_gradient to work.
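For what it's worth, a sketch of how I understand the trick (the Gaussian model m and loglike below are my own hypothetical illustration with unit variance, not code from this issue): for y ~ N(m(theta), 1) the Fisher information is m'(theta)**2, and wrapping the model mean in stop_hessian makes the Hessian of the log-likelihood drop the (y - m) m'' term, leaving exactly -m'(theta)**2.

# Hypothetical sketch, reusing stop_hessian from the code above.
def m(theta):
    # Hypothetical model mean; any smooth function works here.
    return jnp.tanh(theta)

def loglike(theta, y):
    mu = stop_hessian(m(theta))  # disable second-order terms through m
    return -0.5 * (y - mu) ** 2

theta, y = 0.3, 1.0
# Only the fwd-over-fwd composition works with the current attempt;
# jax.hessian (fwd-over-rev by default) runs into the problem above.
fisher = -jax.jacfwd(jax.jacfwd(loglike))(theta, y)
print(fisher, jax.grad(m)(theta) ** 2)  # should agree: m'(theta)**2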