jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Reverse-mode autodiff of custom_vjp fails #7098

Closed (jakevdp closed this issue 7 months ago)

jakevdp commented 3 years ago

Reported in #7092. Related to the comment here: https://github.com/google/jax/blob/97a5719fcb40af7231b5f803f965063538282f8e/jax/interpreters/ad.py#L198-L200

The issue is that in the backward pass over a custom JVP rule, variables that should correspond to cotangents may be erroneously treated as undefined values.

Compact repro:

import jax

@jax.custom_jvp
def f(a, b):
    return a * b

f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

primals, vjp_fun = jax.vjp(f, 1, 1.0)
vjp_fun(primals)

Traceback (most recent call last):
  File "tmp.py", line 13, in <module>
    vjp_fun(primals)
  File "/Users/vanderplas/github/google/jax/jax/_src/api.py", line 1881, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/ad.py", line 121, in unbound_vjp
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/ad.py", line 218, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/vanderplas/github/google/jax/jax/_src/custom_derivatives.py", line 371, in _custom_jvp_call_jaxpr_transpose
    return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
  File "/Users/vanderplas/github/google/jax/jax/interpreters/ad.py", line 218, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/vanderplas/github/google/jax/jax/_src/lax/lax.py", line 2740, in _mul_transpose
    assert ad.is_undefined_primal(x) ^ ad.is_undefined_primal(y)
AssertionError

mavenlin commented 2 years ago

We ran into exactly this issue. Is there an ongoing PR or any workaround?
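One possible workaround, offered as a sketch rather than a confirmed fix: the traceback above passes through `_custom_jvp_call_jaxpr_transpose`, which suggests the failure involves transposing the nested call to `f` inside the JVP rules. Assuming that is the trigger, writing the rules with ordinary operations avoids the nested call. The name `g` is hypothetical, and float inputs are used here to keep integer handling out of the picture.

```python
import jax
import jax.numpy as jnp

# Hypothetical variant of f: same primal computation, but the JVP rules
# use plain multiplication instead of calling the custom_jvp function again.
@jax.custom_jvp
def g(a, b):
    return a * b

g.defjvps(
    lambda t, ans, a, b: t * b,  # tangent contribution from a
    lambda t, ans, a, b: a * t,  # tangent contribution from b
)

primal_out, vjp_fun = jax.vjp(g, 2.0, 3.0)
# Pull back a cotangent of ones with the same shape and dtype as the output.
ct_a, ct_b = vjp_fun(jnp.ones_like(primal_out))
```

With these inputs, `ct_a` should equal `b` and `ct_b` should equal `a`, as expected for a product.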

rajasekharporeddy commented 7 months ago

Hi @jakevdp

It looks like this issue has been resolved. I ran the repro above in Google Colab with JAX 0.4.23 and it works fine.

import jax

@jax.custom_jvp
def f(a, b):
    return a * b

f.defjvps(
    lambda t, _, a, b: f(t, b),
    lambda t, _, a, b: f(a, t),
)

primals, vjp_fun = jax.vjp(f, 1, 1.0)
vjp_fun(primals)

Output:

(array((b'',), dtype=[('float0', 'V')]),
 Array(1., dtype=float32, weak_type=True))

Kindly find the gist for reference.

Thank you
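The first entry in the output above has dtype `float0`, which is how JAX represents the cotangent of an integer-valued input (here the argument `1`). The sketch below checks this on a plain function without `custom_jvp`. The name `h` is hypothetical, and the behavior is assumed to match recent JAX releases.

```python
import jax
import jax.numpy as jnp
from jax.dtypes import float0

# Hypothetical plain function with the same signature as the repro's f.
def h(a, b):
    return a * b

# The first argument is an integer, the second a float.
out, h_vjp = jax.vjp(h, 1, 1.0)
ct_a, ct_b = h_vjp(jnp.ones_like(out))

# The integer input receives a trivial float0 cotangent;
# the float input receives an ordinary floating-point cotangent.
a_is_float0 = ct_a.dtype == float0
```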