f0uriest opened this issue 1 year ago
You might be interested in this,
https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py
I seem to get the same error when using scan, so it may not be unique to while loops.
This is interesting. I have no problems using `while_loop` with a custom JVP rule in my case, but I don't know what is happening in yours (I reproduced the error). Maybe using `dummy_integrate` twice in `_dummy_integrate_jvp` caused the error. I defined `dummy_integrate2`, which is exactly the same as `dummy_integrate`, and replaced the second `dummy_integrate` with `dummy_integrate2`. I got the following error:
```
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
```
This was interesting because I thought JAX does not differentiate the `_dummy_integrate_jvp` function!
@f0uriest -- I think the issue you're seeing here is that the `df` function (defined inside the custom JVP rule) closes over `argsdot`, but is then passed via `nondiff_argnums` to `dummy_integrate` (the one that returns `f2`).

When using a `custom_jvp` or `custom_vjp` rule, you must ensure any function-valued arguments don't close over any tracers. (Typically I'd recommend always making them global functions as a way to avoid any possibility of this.) You should rewrite things so that the `argsdot` tracers are passed as formal arguments.
@ToshiyukiBandai -- Right, you're now bumping into the second issue with the original code (and which f0uriest will hit after having fixed the above issue)!
Explaining this one requires knowing a bit about JAX internals. Buckle up, this gets a bit complicated.
JAX performs VJPs (backpropagation) by first figuring out what the tangent part of a JVP rule looks like, and then transposing it (loosely speaking, the "transpose" here is the "run it backwards" part of backpropagation). In terms of public JAX APIs, what this corresponds to is that `jax.grad` is built by applying `jax.linear_transpose` to the tangent inputs and outputs of `jax.jvp`.
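Concretely (an illustrative sketch of my own, not code from the thread), you can recover `jax.grad` by hand as the transpose of the linearized function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = 1.5

# jax.linearize runs the JVP machinery once, returning the primal output and
# a linear function that maps input tangents to output tangents at x.
y, jvp_fn = jax.linearize(f, x)

# Transposing that linear map is the "run it backwards" step.
vjp_fn = jax.linear_transpose(jvp_fn, x)

# For a scalar output, pulling back a cotangent of 1.0 recovers the gradient.
(g,) = vjp_fn(1.0)
assert jnp.allclose(g, jax.grad(f)(x))
```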
Unfortunately, `jax.lax.while_loop` does not support transposition. Backpropagating through a while loop would require saving the result of every step. As it's a while loop, the number of steps is not known at compile time. That means the amount of memory needed is not known at compile time. And the XLA compiler only performs static memory allocation. Thus, no backpropagating through `jax.lax.while_loop`.
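A tiny toy (my own names, not from the thread) showing the asymmetry: forward mode through a value-dependent while loop is fine, but reverse mode raises:

```python
import jax

def halve_until_small(x):
    # The trip count depends on the runtime value of x.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)

jax.jvp(halve_until_small, (8.0,), (1.0,))  # forward mode: fine
jax.grad(halve_until_small)(8.0)            # raises: reverse mode unsupported
```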
The error message you're seeing is there to help catch the common case of `jax.grad(jax.lax.while_loop)`, i.e. basically `jax.linear_transpose(jax.jvp(jax.lax.while_loop))`.
In the case of this example, then as you note, we're already inside the JVP rule. Thus what we're actually doing is `jax.linear_transpose(jax.lax.while_loop)`. This is an equally impossible operation to perform; it's just that the error message is only designed to help with the common error described in the previous paragraph.
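You can hit that same wall directly (again a toy of my own): even a perfectly linear function containing a while loop cannot be transposed, and it fails with a ValueError much like the one quoted above:

```python
import jax

def linear_while(t):
    # Linear in t, but still wrapped in a while loop.
    return jax.lax.while_loop(lambda v: False, lambda v: v, t)

f_t = jax.linear_transpose(linear_while, 1.0)
f_t(1.0)  # raises: while_loop has no transpose rule
```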
Phew! Okay, what are the possible fixes? You've got a few options:

1) Write a custom JAX primitive. This will allow you to define both JVP and transposition rules to your heart's content.

2) Fix the issue I first described, then use `jax.custom_vjp`. This won't allow you to perform forward-mode autodiff, though. (A minimal sketch of this option is shown below.)

FWIW, this highlights a use-case for #17840, which adds support for jvp-of-custom_vjp. If the approach there can be accepted and the PR finished off, then you'll be able to use `custom_vjp` without having to sacrifice support for JVPs.
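A minimal, self-contained sketch of option 2 (the names `integrate` and `_quad` are my own, and a trivial fixed-step sum stands in for the adaptive loop from the issue). The backward rule reuses the Leibniz-rule idea: the cotangent of each element of `args` is the output cotangent times the integral of the derivative of `fun` with respect to that element:

```python
import functools

import jax
import jax.numpy as jnp


def _quad(fun, a, step, args, n=100):
    # Stand-in integrator: a plain fixed-step sum instead of the adaptive loop.
    xs = a + step * jnp.arange(n)
    return jax.vmap(lambda x: fun(x, *args))(xs).sum()


@functools.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def integrate(fun, a, step, args):
    return _quad(fun, a, step, args)


def _integrate_fwd(fun, a, step, args):
    # Residuals for the backward pass: just the primal args.
    return integrate(fun, a, step, args), args


def _integrate_bwd(fun, a, step, args, ct):
    # Leibniz rule (ignoring boundary terms): the cotangent of args[i] is
    # ct * integral of d(fun)/d(args[i]).
    def d_arg(i):
        return lambda x, *vargs: jax.grad(fun, argnums=i + 1)(x, *vargs)

    grads = tuple(ct * integrate(d_arg(i), a, step, args) for i in range(len(args)))
    return (grads,)


integrate.defvjp(_integrate_fwd, _integrate_bwd)

# Reverse mode now works; forward mode through `integrate` no longer does.
print(jax.grad(lambda c: integrate(lambda x, c: jnp.exp(-c * x), 1.0, 0.01, (c,)))(1.2))
```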
@patrick-kidger Thank you, that was very educational! I hope your PR or something similar will be approved. I also want a way to define both VJP and JVP rules for large-scale inverse problems.
Thanks for all the information.
Thanks for all the help @patrick-kidger. I finally got around to working a bit more on this and am running into what might be a related problem. I've fixed `df` so that it doesn't close over anything, and I'm using `scan` instead of `while_loop`, so I think it should be fine to transpose.

Updated code:
```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def bounded_while_loop(condfun, bodyfun, init_val, bound):
    """While loop for bounded number of iterations, implemented using cond and scan."""
    # could do some fancy stuff with checkpointing here like in equinox but the loops
    # in quadax usually only do ~100 iterations max so probably not worth it.
    def scanfun(state, *args):
        return jax.lax.cond(condfun(state), bodyfun, lambda x: x, state), None

    return jax.lax.scan(scanfun, init_val, None, bound)[0]


@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2))
def dummy_integrate(fun, a, step, args):
    def condfun(state):
        x, f, fx = state
        return abs(fx) / abs(f) > 1e-2

    def bodyfun(state):
        x, f, fx = state
        x += step
        fx = fun(x, *args)
        f += fx
        return x, f, fx

    return bounded_while_loop(condfun, bodyfun, (a, fun(a, *args), np.inf), bound=100)[1]


@dummy_integrate.defjvp
def _dummy_integrate_jvp(fun, a, step, primals, tangents):
    assert len(primals) == len(tangents) == 1
    args = primals[0]
    argsdot = tangents[0]
    assert isinstance(args, tuple)
    assert isinstance(argsdot, tuple)
    assert len(args) == len(argsdot) == 1
    assert args[0] == 1.2
    assert a == 1
    assert step == 0.01
    f1 = dummy_integrate(fun, a, step, args)

    # ignoring boundary terms, derivative of integral is integral of derivative
    def df(x, vargs, vargsdot):
        return jax.jvp(fun, (x, *vargs), (jnp.zeros_like(x), *vargsdot))[1]

    f2 = dummy_integrate(df, a, step, (args, argsdot))
    return f1, f2


fun = lambda x, c: jnp.exp(-c * x)
a = 1.0
step = 0.01
c = 1.2


def bar(c):
    return dummy_integrate(fun, a, step, (c,))


# this runs fine
jax.jacfwd(bar)(1.2)

# this fails with the error below
jax.jacrev(bar)(1.2)
```
I'm now getting an assertion error from `_scan_transpose` about undefined primals:
```
JaxStackTraceBeforeTransformation: AssertionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
Cell In[23], line 63
     60 jax.jacfwd(bar)(1.2)
     62 # this fails with the error below
---> 63 jax.jacrev(bar)(1.2)

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api.py:901, in jacrev.<locals>.jacfun(*args, **kwargs)
    899 y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    900 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
--> 901 jac = vmap(pullback)(_std_basis(y))
    902 jac = jac[0] if isinstance(argnums, int) else jac
    903 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

    [... skipping hidden 8 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:707, in _scan_transpose(cts, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose, *args)
    705 ires, _ = split_list(consts, [num_ires])
    706 _, eres = split_list(xs, [sum(xs_lin)])
--> 707 assert not any(ad.is_undefined_primal(r) for r in ires)
    708 assert not any(ad.is_undefined_primal(r) for r in eres)
    710 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])

AssertionError:
```
Any ideas?
Hmm. So 'undefined primals' here actually refers to the tangents -- in this case `argsdot` -- whose values aren't available when transposing. Somehow the JVP rule of the scan is saving such an undefined primal as one of its residual values (the values that the forward pass saves for the backward pass).

If you figure this out then I'd be curious to know the answer!
Description

I'm working on a library for numerical quadrature in JAX, with derivatives defined via the Leibniz rule. I've defined `custom_jvp` rules for my quadrature functions, and they work fine in forward mode, but when trying reverse-mode AD I get a `NotImplementedError`.

The error:

As far as I understand, reverse mode should work here, since the JVP is defined in terms of calls to the primal function. Is there something I'm missing?