chr1sj0nes closed this issue 7 months ago.
Thanks for filing! I suspect that this is a consequence of custom vmap not yet supporting reverse-mode autodiff (custom or otherwise). Still, it's useful to have a specific test to track the interaction with custom AD in particular.
cc #9073
Hi @chr1sj0nes,

It looks like this issue has been resolved. I ran the code from the report with JAX 0.4.23 (CPU and GPU) and 0.3.25 (TPU), and it executed without errors.
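```python
import jax

# Custom reverse-mode rule: no residuals; the cotangent g goes to both a and b.
@jax.custom_vjp
def add(a, b):
    return a + b

add.defvjp(fwd=lambda a, b: (add(a, b), None), bwd=lambda _, g: (g, g))

# Wrap with custom_vmap; the rule calls add on the batched args and marks the output batched.
add = jax.custom_batching.custom_vmap(add)
add.def_vmap(lambda axis_size, in_batched, *args: (add(*args), True))
```

Output:

```
<function __main__.<lambda>(axis_size, in_batched, *args)>
```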
Please find the CPU, GPU, and TPU gists for reference.
Thank you.
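As written, the snippet above only defines `add`, so the `Output` line is the value returned by `def_vmap` (the rule itself) rather than the result of a `jax.vmap` or `jax.grad` call. A minimal sketch that actually exercises both transformations might look like the following. The names `add_vjp`, `add_fwd`, `add_bwd`, `add_vmap_rule`, and `check_grad` are hypothetical, and two details deliberately differ from the reproducer: the forward rule computes `a + b` directly instead of calling `add`, and the batching rule broadcasts any unbatched input to a leading axis of length `axis_size`. Because reverse mode through `custom_vmap` is exactly the interaction this issue tracks, the sketch reports whatever exception `jax.grad` raises instead of assuming either outcome for a given JAX version.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@jax.custom_vjp
def add_vjp(a, b):
    return a + b


def add_fwd(a, b):
    # No residuals are needed: the derivative of a + b is 1 w.r.t. each input.
    return a + b, None


def add_bwd(_, g):
    # The cotangent flows unchanged to both inputs (assumes a and b share a shape).
    return g, g


add_vjp.defvjp(add_fwd, add_bwd)

# custom_vmap applied on top of custom_vjp, as in the reproducer.
add = custom_vmap(add_vjp)


@add.def_vmap
def add_vmap_rule(axis_size, in_batched, a, b):
    # Give any unbatched input a leading axis of length axis_size, then run the
    # unbatched function once on the stacked arguments.
    a, b = (
        x if batched else jnp.broadcast_to(x, (axis_size,) + jnp.shape(x))
        for x, batched in zip((a, b), in_batched)
    )
    return add_vjp(a, b), True  # the output is batched along axis 0


def check_grad():
    """Return d/dx add(x, 1.0) at x = 2.0, or the error raised while computing it."""
    try:
        return float(jax.grad(lambda x: add(x, 1.0))(2.0))
    except Exception as err:  # reverse mode through custom_vmap may be unsupported
        return f"{type(err).__name__}: {err}"


xs = jnp.arange(3.0)
both_batched = jax.vmap(add)(xs, jnp.ones(3))            # dispatches to add_vmap_rule
one_batched = jax.vmap(add, in_axes=(0, None))(xs, 2.0)  # b is broadcast inside the rule
print(both_batched, one_batched, check_grad())
```

The other order mentioned in the report (applying `jax.custom_vjp` to a `custom_vmap` function) could be checked the same way.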
Description

When using both `custom_vjp` and `custom_vmap` on the same function (applied in either order), I get the following error:

Example code:
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response