jakevdp closed this 7 months ago
We ran into exactly this issue. Is there an ongoing PR or any workaround?
Hi @jakevdp
It looks like this issue has been resolved. I ran the code from the report in Google Colab with JAX 0.4.23 and it works fine:
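```python
import jax

@jax.custom_jvp
def f(a, b):
    return a * b

# One JVP rule per argument; each receives (tangent, primal_out, *primals).
f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

# The first primal is a Python int, so its cotangent has dtype float0.
primals, vjp_fun = jax.vjp(f, 1, 1.0)
vjp_fun(primals)
```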
Output:
```
(array((b'',), dtype=[('float0', 'V')]),
 Array(1., dtype=float32, weak_type=True))
```
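For context on the output above: the first cotangent has dtype `float0` because the first primal passed to `jax.vjp` is the Python int `1`, and JAX gives integer inputs zero-size `float0` cotangents. A minimal sketch that checks this, assuming a JAX version where this issue is fixed (0.4.23 per the run above); the names `ct_int` and `ct_float` are just illustrative:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(a, b):
    return a * b

f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

# Mixed int/float primals, as in the snippet above.
out, vjp_fun = jax.vjp(f, 1, 1.0)
ct_int, ct_float = vjp_fun(jnp.ones_like(out))

# The int input gets a zero-size float0 cotangent; the float input gets a real one.
print(ct_int.dtype, ct_float.dtype, ct_float)
```

Passing float primals instead (e.g. `jax.vjp(f, 1.0, 1.0)`) would give two ordinary float32 cotangents.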
Please see the gist for reference.
Thank you
Reported in #7092. Related to the comment here: https://github.com/google/jax/blob/97a5719fcb40af7231b5f803f965063538282f8e/jax/interpreters/ad.py#L198-L200
The issue is that in the backward pass over a custom JVP rule, variables that should correspond to cotangents may be erroneously treated as undefined values.
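A quick regression check for the behavior described above, as a sketch assuming a JAX release that includes the fix: take `jax.grad` through the custom rule and compare it with a plain `a * b`. The helpers `plain` and `cube` are hypothetical, introduced only for this check; `cube` builds `x**3` from two products so cotangents flow through both JVP rules.

```python
import jax

@jax.custom_jvp
def f(a, b):
    return a * b

# Each per-argument JVP rule calls f again on the tangent, so the backward
# pass has to transpose calls to f that appear inside its own JVP rules.
f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

def plain(a, b):
    # Hypothetical reference product with no custom rule.
    return a * b

def cube(x, mul):
    # Hypothetical helper: x**3 computed as mul(mul(x, x), x).
    return mul(mul(x, x), x)

x = 1.5
custom_grad = jax.grad(lambda v: cube(v, f))(x)
plain_grad = jax.grad(lambda v: cube(v, plain))(x)
print(custom_grad, plain_grad)  # both should equal 3 * x**2
```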
Compact repro: