jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

VJP error when `custom_vjp` and `custom_vmap` are used together. #14150

Closed · chr1sj0nes closed this issue 7 months ago

chr1sj0nes commented 1 year ago

Description

When using both `custom_vjp` and `custom_vmap` on the same function (applied in either order), I get the following error:

  File "/build/work/88c64dd2a6f04c75e55857743c6570e5c618/google3/runfiles/google3/third_party/py/jax/_src/api.py", line 2641, in vjp
    return _vjp(
  File "/build/work/88c64dd2a6f04c75e55857743c6570e5c618/google3/runfiles/google3/third_party/py/jax/_src/api.py", line 2650, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/build/work/88c64dd2a6f04c75e55857743c6570e5c618/google3/runfiles/google3/third_party/py/jax/interpreters/ad.py", line 137, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/build/work/88c64dd2a6f04c75e55857743c6570e5c618/google3/runfiles/google3/third_party/py/jax/interpreters/ad.py", line 128, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError

Example code:

import jax

# Function with a custom reverse-mode (VJP) rule.
@jax.custom_vjp
def add(a, b):
  return a + b

# fwd saves no residuals; bwd passes the cotangent through to both inputs.
add.defvjp(fwd=lambda a, b: (add(a, b), None), bwd=lambda _, g: (g, g))

# Wrap the same function with a custom batching (vmap) rule.
add = jax.custom_batching.custom_vmap(add)
add.def_vmap(lambda axis_size, in_batched, *args: (add(*args), True))
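
The report doesn't include the call that produced the traceback above; presumably a reverse-mode call on the wrapped function, roughly as below (the concrete inputs are an assumption, not part of the original report), is what raises the AssertionError:

# Assumed trigger: taking a VJP of the doubly-wrapped function.
out, f_vjp = jax.vjp(add, 1.0, 2.0)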

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

froystig commented 1 year ago

Thanks for filing! I suspect that this is a consequence of custom vmap not yet supporting reverse-mode autodiff (custom or otherwise). Still, it's useful to have a specific test to track the interaction with custom AD in particular.
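
If that is the cause, one way to narrow it down (a sketch, not from this thread; the concrete inputs are assumptions) is to compare forward- and reverse-mode on the wrapped function:

# Forward-mode may still go through, since the suspicion is specifically
# about reverse-mode support in custom_vmap.
jax.jvp(add, (1.0, 2.0), (1.0, 0.0))

# Reverse-mode is where the AssertionError above is raised.
jax.vjp(add, 1.0, 2.0)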

cc #9073

rajasekharporeddy commented 7 months ago

Hi @chr1sj0nes

This issue looks to have been resolved. I ran the reported code with JAX versions 0.4.23 (CPU and GPU) and 0.3.25 (TPU), and it executed without any error.

import jax

@jax.custom_vjp
def add(a, b):
  return a + b

add.defvjp(fwd=lambda a, b: (add(a, b), None), bwd=lambda _, g: (g, g))

add = jax.custom_batching.custom_vmap(add)
add.def_vmap(lambda axis_size, in_batched, *args: (add(*args), True))

Output:

<function __main__.<lambda>(axis_size, in_batched, *args)>
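
Note that this output is just the repr of the vmap rule (the lambda) returned by def_vmap; the snippet as written never takes a VJP. A fuller check (a sketch; the concrete inputs and the jax.numpy import are assumptions, not from the thread) would exercise both transforms on the wrapped function:

import jax.numpy as jnp

# Reverse-mode through the custom_vmap-wrapped function, which is the
# combination that originally raised the AssertionError.
out, f_vjp = jax.vjp(add, 1.0, 2.0)
print(out, f_vjp(1.0))

# Batching, to confirm the custom vmap rule is also exercised.
print(jax.vmap(add)(jnp.ones(3), jnp.ones(3)))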

Kindly find the gists on CPU, GPU and TPU for reference.

Thank you.