google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unexpected error related to `jax.custom_jvp` #7972

Closed tianjuxue closed 1 year ago

tianjuxue commented 2 years ago

@denizokt and I are working with jax.custom_jvp and want to report an unexpected error. The following code reproduces it.

import jax.numpy as jnp
import jax

def identify_fn(x):
  return x

identify_fn = jax.custom_jvp(identify_fn)

@identify_fn.defjvp
def identify_fn_jvp(primals, tangents):
  aug_init_state = jnp.concatenate((*primals, *tangents))
  ys = aug_init_state[:3]
  ys_dot = aug_init_state[3:]
  return ys, ys_dot
  # Won't cause an error if we do the following instead
  # return (*primals, *tangents)

jac = jax.jacfwd(identify_fn)(jnp.array([5., 5., 5.]))

Error message: ValueError: vmap has mapped output but out_axes is None

Thanks!

denizokt commented 2 years ago

Some more information:

The above code works when jax.jacfwd is replaced with jax.jvp. It also works when the line return ys, ys_dot is replaced with return primals[0], ys_dot.

It seems like when the primal computation depends on the tangents, we get a vmap error when computing full Jacobians.
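
The second observation above can be sketched as follows. This is a minimal variant of the reproduction, assuming a length-3 input as in the original; the function is renamed identity_fn here purely for readability. The tangent output still goes through the concatenated array, but the primal output is taken from primals alone, so jacfwd is expected to run:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def identity_fn(x):
    return x

@identity_fn.defjvp
def identity_fn_jvp(primals, tangents):
    # Stacking primals and tangents is fine for computing the tangent output.
    aug_init_state = jnp.concatenate((*primals, *tangents))
    ys_dot = aug_init_state[3:]
    # The primal output is taken from primals only, not sliced out of the
    # stacked array, so it never depends on the tangents.
    return primals[0], ys_dot

jac = jax.jacfwd(identity_fn)(jnp.array([5., 5., 5.]))
```

For the identity function the resulting Jacobian should be the 3x3 identity matrix.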

jakevdp commented 2 years ago

Thanks for raising the issue!

> It seems like when the primal computation depends on the tangents, we get a vmap error when computing full Jacobians.

I think that's the correct synopsis... I'm having trouble deciding whether this is a true bug, or just reflects an unintended use. Are there any non-contrived examples in which the primals depend on the tangents? It seems to me that in itself would indicate the JVP function is malformed.

Perhaps the best fix here would be to raise a clearer error message in this case?

tianjuxue commented 2 years ago

Thanks for the clarification. For a non-contrived example, the odeint_rk4_jvp function defined in this tutorial causes the same issue. It would be great if this could be fixed. Otherwise, I think odeint_rk4_jvp could be rewritten so that the primal computation is decoupled from the tangents, but then we probably couldn't use the augmented ODE system trick from that tutorial example. The same trick is also used in jax/experimental/ode.py.
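
One way such a decoupling might look is sketched below. This is not the tutorial's odeint_rk4_jvp and not the code in jax/experimental/ode.py: it swaps RK4 for a fixed-step forward Euler loop, and f, integrate, DT and N_STEPS are hypothetical names chosen for illustration. The idea is to carry the augmented state as a (primal, tangent) pair instead of one concatenated array, so the primal half is only ever computed from primals:

```python
import jax
import jax.numpy as jnp

DT = 0.1       # hypothetical step size
N_STEPS = 10   # hypothetical number of steps

def f(y):
    # Hypothetical right-hand side dy/dt = -y, chosen for illustration.
    return -y

@jax.custom_jvp
def integrate(y0):
    # Forward Euler, standing in for the tutorial's RK4 integrator.
    y = y0
    for _ in range(N_STEPS):
        y = y + DT * f(y)
    return y

@integrate.defjvp
def integrate_jvp(primals, tangents):
    (y0,), (v0,) = primals, tangents
    # Augmented system kept as a pair rather than a stacked array:
    # y evolves from primals only, v is the linearized system along y.
    y, v = y0, v0
    for _ in range(N_STEPS):
        dy, dv = jax.jvp(f, (y,), (v,))
        y, v = y + DT * dy, v + DT * dv
    return y, v

y0 = jnp.array([5., 5., 5.])
jac = jax.jacfwd(integrate)(y0)
```

For this linear right-hand side each Euler step scales the state by (1 - DT), so the Jacobian should be (1 - DT) ** N_STEPS times the identity. Whether the same restructuring is practical for the real tutorial code is not settled in this thread.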

jakevdp commented 2 years ago

OK, thanks. Agreed, that's unfortunate. To put it in words: if you use jacfwd and your JVP calls a function that requires the primals and tangents to be stacked together, the vmapped tangents will "infect" the primals with their batching tracer, and jacfwd will fail, since it doesn't expect the output primals to be batched.
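
The "infection" can be reproduced with vmap alone, without custom_jvp. The sketch below loosely mimics the situation described above, under the assumption that the tangents are mapped over a basis while the primal output is requested unbatched via out_axes=None. The helper names split_after_stack and keep_separate are hypothetical:

```python
import jax
import jax.numpy as jnp

x = jnp.array([5., 5., 5.])   # plays the role of the primal input
basis = jnp.eye(3)            # plays the role of the batched tangents

def split_after_stack(t):
    # Both outputs are sliced from the stacked array, so both are batched.
    aug = jnp.concatenate((x, t))
    return aug[:3], aug[3:]

def keep_separate(t):
    # The first output comes straight from x and stays unbatched.
    aug = jnp.concatenate((x, t))
    return x, aug[3:]

# Works: the first output really is independent of the mapped input.
ys, ys_dot = jax.vmap(keep_separate, out_axes=(None, 0))(basis)

# Fails: the first output is batched, but out_axes says it should not be.
try:
    jax.vmap(split_after_stack, out_axes=(None, 0))(basis)
    raised = False
except ValueError as err:
    raised = True
    message = str(err)
```

Here the values in aug[:3] are numerically the same for every basis vector, but vmap tracks batching by data flow, not by value, which is why the slice still counts as mapped.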

froystig commented 2 years ago

This is less of a bug in functionality, more an issue with under-documentation and error messages. For JVPs, JAX requires that the primal outputs be independent of the input tangents.

We don't check this directly in JVPs or forward mode, but we require it once other transformations such as reverse mode take place, which is when the resulting errors happen (and can be confusing). It is possible today to get away with writing a JVP whose primal output depends on input tangents, and not all transformations will take issue with it (e.g. jvp), although this is not something we've intended to promise.
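
As an illustration of the jvp case mentioned above, the sketch below assembles a Jacobian one column at a time with plain jax.jvp calls, so no vmap is involved. It reuses the JVP rule from the original report, with the function renamed identity_fn; the helper jacobian_by_columns is a hypothetical name. As noted above, this relies on behavior that is not promised, so it should be read as a demonstration, not a recommended pattern:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def identity_fn(x):
    return x

@identity_fn.defjvp
def identity_fn_jvp(primals, tangents):
    # Same rule as in the report: the primal output is sliced from the
    # stacked array and therefore depends on the tangents by data flow.
    aug_init_state = jnp.concatenate((*primals, *tangents))
    return aug_init_state[:3], aug_init_state[3:]

def jacobian_by_columns(fn, x):
    # One jax.jvp call per standard basis vector, in a Python loop.
    cols = []
    for e in jnp.eye(x.shape[0], dtype=x.dtype):
        _, col = jax.jvp(fn, (x,), (e,))
        cols.append(col)
    return jnp.stack(cols, axis=-1)

x = jnp.array([5., 5., 5.])
jac = jacobian_by_columns(identity_fn, x)
```

The loop costs one JVP evaluation per input dimension and is not batched, which is the trade-off for avoiding vmap here.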

We haven't figured out how to check this property of JVPs upfront, but we can try to mitigate the surprises with better documentation in the context of custom_jvp, and perhaps by adding mention of this to the downstream error messages that it often causes.