tianjuxue closed this issue 1 year ago.
Some more information:
The above code works when replacing jax.jacfwd with jax.jvp.
It also works when replacing return ys, ys_dot with return primals[0], ys_dot.
It seems like when the primal computation depends on the tangents, we get a vmap error when computing full Jacobians.
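A minimal sketch of this failure mode follows. The original reproduction code is not included in this thread, so sin_coupled and sin_decoupled are hypothetical stand-ins. The coupled JVP rule stacks the primal and tangent into one array and reads the primal output back out of that stack. jax.jvp accepts it, while jax.jacfwd, which vmaps the JVP over basis tangents and expects an unbatched primal output, should reject it. The decoupled variant computes the primal output from the primals alone, in the spirit of the return primals[0], ys_dot workaround.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: the primal output is sliced out of an array that
# also holds the tangent, so under vmap it inherits the tangent's batch dim.
@jax.custom_jvp
def sin_coupled(x):
    return jnp.sin(x)

@sin_coupled.defjvp
def _sin_coupled_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    z = jnp.stack([x, x_dot])          # primal and tangent stacked together
    ys = jnp.sin(z[0])                 # numerically depends only on x ...
    ys_dot = jnp.cos(z[0]) * z[1]      # ... but structurally on x_dot as well
    return ys, ys_dot

# Hypothetical fix: the primal output is computed from the primals alone.
@jax.custom_jvp
def sin_decoupled(x):
    return jnp.sin(x)

@sin_decoupled.defjvp
def _sin_decoupled_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    z = jnp.stack([x, x_dot])
    ys_dot = jnp.cos(z[0]) * z[1]      # the tangent may still use the stack
    return jnp.sin(x), ys_dot

x = jnp.array([0.1, 0.2, 0.3])

# A single JVP works with both rules.
_, t1 = jax.jvp(sin_coupled, (x,), (jnp.ones_like(x),))
_, t2 = jax.jvp(sin_decoupled, (x,), (jnp.ones_like(x),))

# jacfwd vmaps the JVP over basis tangents with out_axes=None for the primal.
try:
    jax.jacfwd(sin_coupled)(x)
    coupled_ok = True
except ValueError as err:
    coupled_ok = False
    print("coupled rule failed:", err)

J = jax.jacfwd(sin_decoupled)(x)       # Jacobian of elementwise sin
print(J)
```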
Thanks for raising the issue!
Quoting: "It seems like when the primal computation depends on the tangents, we get a vmap error when computing full Jacobians."
I think that's the correct synopsis. I'm having trouble deciding whether this is a true bug or just reflects an unintended use. Are there any non-contrived examples in which the primals depend on the tangents? It seems to me that this in itself would indicate the JVP function is malformed.
Perhaps the best fix here would be to raise a clearer error message in this case?
Thanks for the clarification. For a non-contrived example, the odeint_rk4_jvp function defined in this tutorial would cause the same issue. It'd be great if this issue could be fixed. Otherwise, I think it's possible to rewrite odeint_rk4_jvp so that the primal computation is decoupled from the tangents, but then we probably can't use the augmented ODE system trick from that tutorial example. The same trick is also used in jax/experimental/ode.py.
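A simplified sketch of the decoupled rewrite for an ODE solver is below. It assumes a hypothetical fixed-step RK4 integrator rk4, a hypothetical pendulum right-hand side, and a JVP with respect to the initial state only; the tutorial's odeint_rk4_jvp is not reproduced here. The tangent is still obtained with the augmented-system trick (stacking the state with its tangent and integrating the linearized dynamics), but the primal trajectory returned by the rule is re-solved from y0 alone, so it cannot inherit a batch dimension from the tangents under jacfwd. This sketch only targets forward mode (jvp and jacfwd).

```python
import jax
import jax.numpy as jnp

TS = jnp.linspace(0.0, 1.0, 21)  # hypothetical fixed time grid

def pendulum(y, t):
    # Hypothetical right-hand side dy/dt = f(y, t), state y = (angle, velocity).
    return jnp.array([y[1], -jnp.sin(y[0])])

def rk4(fun, y0, ts):
    # Classic fixed-step RK4; returns the state at every time in ts.
    def step(y, t_pair):
        t0, t1 = t_pair
        h = t1 - t0
        k1 = fun(y, t0)
        k2 = fun(y + h / 2 * k1, t0 + h / 2)
        k3 = fun(y + h / 2 * k2, t0 + h / 2)
        k4 = fun(y + h * k3, t1)
        y_next = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next
    _, ys = jax.lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])

@jax.custom_jvp
def odeint(y0):
    return rk4(pendulum, y0, TS)

@odeint.defjvp
def odeint_jvp(primals, tangents):
    (y0,), (y0_dot,) = primals, tangents

    def augmented(z, t):
        # z stacks the state and its tangent; the tangent follows the
        # linearized ODE  y_dot' = (df/dy) y_dot.
        y, y_dot = z[0], z[1]
        dy, dy_dot = jax.jvp(lambda u: pendulum(u, t), (y,), (y_dot,))
        return jnp.stack([dy, dy_dot])

    zs = rk4(augmented, jnp.stack([y0, y0_dot]), TS)
    ys_dot = zs[:, 1]
    # Returning zs[:, 0] here would make the primal inherit the tangent's
    # batch dimension under jacfwd. Re-solving from y0 alone avoids that,
    # at the cost of integrating the primal system a second time.
    ys = rk4(pendulum, y0, TS)
    return ys, ys_dot

y0 = jnp.array([0.5, 0.0])
jac = jax.jacfwd(odeint)(y0)  # shape (21, 2, 2)
```

The price of this design is one extra forward solve per JVP. For an explicit fixed-step scheme like this RK4, the tangent from the augmented system should agree (up to rounding) with differentiating the RK4 steps directly.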
OK, thanks. Agreed that's unfortunate. To put it in words: if you use jacfwd and your JVP calls a function that requires the primals and tangents to be stacked together, the vmapped tangents will "infect" the primals with their batching tracer, and jacfwd will fail (since it doesn't expect the output primals to be batched).
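The "infection" can be reproduced with vmap alone. In this sketch (all names are hypothetical), the value sliced out of a stack that contains the mapped argument is numerically just x, yet vmap still treats it as batched: with out_axes=0 it returns one identical copy per tangent, and with out_axes=None, which is how jacfwd treats the primal output, it raises a ValueError.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0])
tangent_batch = jnp.eye(2)  # one basis tangent per row, as jacfwd uses

def primal_from_stack(t):
    # The "primal" is sliced back out of a stack that contains the tangent t.
    return jnp.stack([x, t])[0]

# With out_axes=0, vmap returns one copy of x per tangent: the output is
# batched even though its values never depended on t.
copies = jax.vmap(primal_from_stack, out_axes=0)(tangent_batch)

# With out_axes=None (how jacfwd handles the primal output), it is rejected.
try:
    jax.vmap(primal_from_stack, out_axes=None)(tangent_batch)
    rejected = False
except ValueError as err:
    rejected = True
    print(err)
```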
This is less a bug in functionality and more an issue of under-documentation and error messages. For JVPs, JAX requires that the primal outputs be independent of the input tangents.
We don't check this directly in JVPs or forward mode, but we rely on it once other transformations such as reverse mode are applied, which is when the resulting (and often confusing) errors appear. It is possible today to get away with writing a JVP whose primal output depends on the input tangents, and not all transformations will take issue with it (e.g. jvp), although this is not something we have intended to promise.
We haven't figured out how to check this property of JVPs upfront, but we can try to mitigate the surprises with better documentation in the context of custom_jvp, and perhaps by mentioning it in the downstream error messages it often causes.
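As a debugging aid when calling a JVP rule directly, a conservative structural check is possible. The helper below is a hypothetical sketch, not part of JAX: it traces the rule with jax.make_jaxpr and marks every jaxpr variable downstream of a tangent input, treating any equation with a marked input (including scan or pjit) as marking all of its outputs. It can therefore report false positives, but a primal output it leaves unmarked is computed only from non-tangent values, so it should not pick up a batch dimension from the tangents in the way described above.

```python
import jax
import jax.numpy as jnp

def primal_out_depends_on_tangents(jvp_rule, primals, tangents):
    # Hypothetical helper, not a JAX API. Conservative: an equation with any
    # tangent-dependent input marks all of its outputs as tangent-dependent.
    closed, out_shape = jax.make_jaxpr(jvp_rule, return_shape=True)(primals, tangents)
    jaxpr = closed.jaxpr
    n_primal_in = len(jax.tree_util.tree_leaves(primals))
    n_primal_out = len(jax.tree_util.tree_leaves(out_shape[0]))

    # Track jaxpr variables by identity; literals are never marked.
    marked = {id(v) for v in jaxpr.invars[n_primal_in:]}
    for eqn in jaxpr.eqns:
        if any(id(v) in marked for v in eqn.invars):
            marked.update(id(v) for v in eqn.outvars)
    return any(id(v) in marked for v in jaxpr.outvars[:n_primal_out])

# Hypothetical rules in the custom_jvp (primals, tangents) convention.
def coupled_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    z = jnp.stack([x, x_dot])
    return jnp.sin(z[0]), jnp.cos(z[0]) * z[1]

def decoupled_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    z = jnp.stack([x, x_dot])
    return jnp.sin(x), jnp.cos(z[0]) * z[1]

x = jnp.array([0.1, 0.2, 0.3])
t = jnp.ones_like(x)
print(primal_out_depends_on_tangents(coupled_rule, (x,), (t,)))
print(primal_out_depends_on_tangents(decoupled_rule, (x,), (t,)))
```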
@denizokt and I are working with jax.custom_jvp and want to report an unexpected error. The following code reproduces it.
Error message: ValueError: vmap has mapped output but out_axes is None
Thanks!