jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

``custom_jvp`` of ``while_loop`` fails in reverse mode. #18311

Open f0uriest opened 1 year ago

f0uriest commented 1 year ago

Description

I'm working on a library for numerical quadrature in JAX, with derivatives defined via the Leibniz rule. I've defined custom_jvp rules for my quadrature functions and they work fine in forward mode, but when I try reverse-mode AD I get a NotImplementedError:
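For reference, this is the Leibniz integral rule, written for integration limits a(c) and b(c) that may depend on a parameter c:

```latex
\frac{d}{dc}\int_{a(c)}^{b(c)} f(x, c)\,dx
  = \int_{a(c)}^{b(c)} \frac{\partial f}{\partial c}(x, c)\,dx
  + f\bigl(b(c), c\bigr)\,b'(c)
  - f\bigl(a(c), c\bigr)\,a'(c)
```

The JVP rule in the code that follows keeps only the first term (its comment says it is "ignoring boundary terms"). Since dummy_integrate picks its upper limit with a data-dependent stopping condition, the dependence of that endpoint on the arguments is dropped as well.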


import functools

import jax
import jax.numpy as jnp
import numpy as np

@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def dummy_integrate(fun, a, step, *args):

    def condfun(state):
        x, f, fx = state
        return abs(fx)/abs(f) > 1e-2

    def bodyfun(state):
        x, f, fx = state
        x += step
        fx = fun(x, *args)
        f += fx
        return x, f, fx

    return jax.lax.while_loop(condfun, bodyfun, (a,fun(a, *args), np.inf))[1]

@dummy_integrate.defjvp   
def _dummy_integrate_jvp(fun, primals, tangents):
    a, step = primals[:2]
    args = primals[2:]
    adot, stepdot = tangents[:2]
    argsdot = tangents[2:]
    f1 = dummy_integrate(fun, *primals)

    # ignoring boundary terms, derivative of integral is integral of derivative
    def df(x, *args):
        return jax.jvp(fun, (x, *args), (jnp.zeros_like(x), *argsdot))[1]

    f2 = dummy_integrate(df, *primals)
    return f1, f2

fun = lambda x, c : jnp.exp(-c*x)
a = 1.
step = 0.01
c = 1.2

def bar(c):
    return dummy_integrate(fun, a, step, c)

# this runs fine
jax.jacfwd(bar)(1.2)

# this fails with the error below
jax.jacrev(bar)(1.2)

The error:

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:481, in JaxprTrace.post_process_custom_jvp_call(self, out_tracers, _)
    477 def post_process_custom_jvp_call(self, out_tracers, _):
    478   # This path should only be reachable if we expose a partial eval API
    479   # unrelated to autodiff, since we raise an error when differentiation with
    480   # respect to values over which a custom_jvp function closes is detected.
--> 481   raise NotImplementedError

NotImplementedError: 

As far as I understand, reverse mode should work here, since the JVP rule is defined in terms of calls to the primal function. Is there something I'm missing?

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

adam-hartshorne commented 1 year ago

You might be interested in this:

https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py

f0uriest commented 1 year ago

I seem to get the same error when using scan, so it may not be unique to while loops.

ToshiyukiBandai commented 1 year ago

This is interesting. I have no problems using while_loop with a custom JVP rule in my own code, but I don't know what is happening in your case (I reproduced the error). Maybe calling dummy_integrate twice in _dummy_integrate_jvp caused the error. I defined dummy_integrate2, which is exactly the same as dummy_integrate, and replaced the second dummy_integrate call with dummy_integrate2. I got the following error:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

This was interesting because I thought JAX did not differentiate the _dummy_integrate_jvp function!

patrick-kidger commented 1 year ago

@f0uriest -- I think the issue you're seeing here arises because the df function (defined inside the custom JVP rule) closes over argsdot, but is then passed as a nondiff_argnum to dummy_integrate (the call that returns f2).

When using a custom_jvp or custom_vjp rule, you must ensure any function-valued arguments don't close over any tracers. (Typically I'd recommend always making them global functions as a way to avoid any possibility of this.) You should rewrite things so that the argsdot tracers are passed as formal arguments.

@ToshiyukiBandai -- Right, you're now bumping into the second issue with the original code (and which f0uriest will hit after having fixed the above issue)!

Explaining this one requires knowing a bit about JAX internals. Buckle up, this gets a bit complicated.

JAX performs VJPs (backpropagation) by first figuring out what the tangent part of a JVP rule looks like, and then transposing it (loosely speaking, the "transpose" here is the "run it backwards" part of backpropagation). In terms of public JAX APIs, what this corresponds to is that jax.grad is built by applying jax.linear_transpose to the tangent inputs and outputs of jax.jvp.
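A small illustration of that correspondence, assuming a toy function f that is not from the issue: the tangent map v -> J v produced by jax.jvp at a fixed primal point is linear, and jax.linear_transpose turns it into the cotangent map, which should match jax.grad.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy function.
    return jnp.sin(x) * x


x = jnp.float32(0.7)


def tangent_map(v):
    # v -> J(x) v: the tangent output of jax.jvp at the fixed primal point x.
    return jax.jvp(f, (x,), (v,))[1]


# Transposing the linear tangent map gives ct -> J(x)^T ct, i.e. the VJP at x.
vjp_at_x = jax.linear_transpose(tangent_map, x)
(grad_via_transpose,) = vjp_at_x(jnp.float32(1.0))

print(grad_via_transpose, jax.grad(f)(x))
```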

Unfortunately, jax.lax.while_loop does not support transposition. Backpropagating through a while loop would require saving the result of every step. As it's a while loop, the number of steps is not known at compile time. That means the amount of memory needed is not known at compile time. And the XLA compiler only performs static memory allocation. Thus, no backpropagating through jax.lax.while_loop.
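A minimal sketch contrasting the two loop forms, assuming a hypothetical toy update step. With static Python-int bounds, jax.lax.fori_loop has a trip count known at trace time and is lowered to a scan, so jax.grad works. The same computation written with jax.lax.while_loop still supports forward mode, but under jax.grad it should raise the ValueError quoted earlier in the thread.

```python
import jax
import jax.numpy as jnp


def update(x):
    # Hypothetical toy update applied ten times.
    return 0.9 * x + jnp.sin(x)


def ten_steps_fori(x):
    # Python-int bounds make the trip count static, so JAX lowers this
    # fori_loop to a scan, which supports reverse mode.
    return jax.lax.fori_loop(0, 10, lambda i, x: update(x), x)


def ten_steps_while(x):
    # The same computation as a while_loop: while_loop does not assume a
    # fixed trip count, so reverse mode is rejected.
    def cond(state):
        i, _ = state
        return i < 10

    def body(state):
        i, x = state
        return i + 1, update(x)

    return jax.lax.while_loop(cond, body, (0, x))[1]


x0 = jnp.float32(1.0)
grad_fori = jax.grad(ten_steps_fori)(x0)     # reverse mode through scan
fwd_while = jax.jacfwd(ten_steps_while)(x0)  # forward mode through while_loop

try:
    jax.grad(ten_steps_while)(x0)
    while_grad_failed = False
except ValueError as err:
    while_grad_failed = True
    print("jax.grad(while_loop) failed:", err)
```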

The error message you're seeing is there to help catch the common case of jax.grad(jax.lax.while_loop), i.e. basically jax.linear_transpose(jax.jvp(jax.lax.while_loop)).

In this example, as you note, we're already inside the JVP rule. Thus what we're actually doing is jax.linear_transpose(jax.lax.while_loop). This is an equally impossible operation to perform; it's just that the error message is only designed to help with the common error described in the previous paragraph.
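To see that the failure concerns transposition itself rather than jax.grad specifically, here is a sketch that calls jax.linear_transpose directly on a function that is linear in its input but written with jax.lax.while_loop. The function double_three_times is a hypothetical toy; under these assumptions the call is expected to raise the same reverse-mode ValueError, even though the function is mathematically just 8 * v.

```python
import jax
import jax.numpy as jnp


def double_three_times(v):
    # Hypothetical toy function: linear in v (it returns 8 * v), but written
    # with a while_loop.
    def cond(state):
        i, _ = state
        return i < 3

    def body(state):
        i, w = state
        return i + 1, 2.0 * w

    return jax.lax.while_loop(cond, body, (0, v))[1]


example = jnp.float32(1.0)

try:
    transpose_fn = jax.linear_transpose(double_three_times, example)
    transpose_fn(jnp.float32(1.0))
    transpose_failed = False
except ValueError as err:
    transpose_failed = True
    print("linear_transpose(while_loop) failed:", err)
```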


Phew! Okay, what are the possible fixes? You've got a few possible options:

1) Write a custom JAX primitive. This will allow you to define both JVP and transposition rules to your heart's content.
2) Fix the issue I first described, then use jax.custom_vjp. This won't allow you to perform forward-mode autodiff, though.

FWIW, this highlights a use case for #17840, which adds support for jvp-of-custom_vjp. If the approach there is accepted and the PR finished off, then you'll be able to use custom_vjp without having to sacrifice support for JVPs.

ToshiyukiBandai commented 1 year ago

@patrick-kidger Thank you, that was very educational! I hope your PR, or something similar, gets approved. I also want a way to define both VJP and JVP rules for large-scale inverse problems.

siliakao commented 1 year ago

Thanks for all the information.

f0uriest commented 4 months ago

Thanks for all the help @patrick-kidger. I finally got around to working a bit more on this, and I'm running into what might be a related problem. I've fixed df so that it doesn't close over anything, and I'm using scan instead of while_loop, so I think it should be fine to transpose.

Updated code:

import functools

import jax
import jax.numpy as jnp
import numpy as np

def bounded_while_loop(condfun, bodyfun, init_val, bound):
    """While loop for bounded number of iterations, implemented using cond and scan."""
    # could do some fancy stuff with checkpointing here like in equinox but the loops
    # in quadax usually only do ~100 iterations max so probably not worth it.

    def scanfun(state, *args):
        return jax.lax.cond(condfun(state), bodyfun, lambda x: x, state), None

    return jax.lax.scan(scanfun, init_val, None, bound)[0]

@functools.partial(jax.custom_jvp, nondiff_argnums=(0,1,2))
def dummy_integrate(fun, a, step, args):

    def condfun(state):
        x, f, fx = state
        return abs(fx)/abs(f) > 1e-2

    def bodyfun(state):
        x, f, fx = state
        x += step
        fx = fun(x, *args)
        f += fx
        return x, f, fx

    return bounded_while_loop(condfun, bodyfun, (a,fun(a, *args), np.inf), bound=100)[1]

@dummy_integrate.defjvp   
def _dummy_integrate_jvp(fun, a, step, primals, tangents):
    assert len(primals) == len(tangents) == 1
    args = primals[0] 
    argsdot = tangents[0]
    assert isinstance(args, tuple)
    assert isinstance(argsdot, tuple)
    assert len(args) == len(argsdot) == 1
    assert args[0] == 1.2
    assert a == 1
    assert step == 0.01

    f1 = dummy_integrate(fun, a, step, args)

    # ignoring boundary terms, derivative of integral is integral of derivative
    def df(x, vargs, vargsdot):
        return jax.jvp(fun, (x, *vargs), (jnp.zeros_like(x), *vargsdot))[1]

    f2 = dummy_integrate(df, a, step, (args, argsdot))
    return f1, f2

fun = lambda x, c : jnp.exp(-c*x)
a = 1.
step = 0.01
c = 1.2

def bar(c):
    return dummy_integrate(fun, a, step, (c,))

# this runs fine
jax.jacfwd(bar)(1.2)

# this fails with the error below
jax.jacrev(bar)(1.2)

I'm now getting an assertion error from _scan_transpose about undefined primals:

JaxStackTraceBeforeTransformation: AssertionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AssertionError                            Traceback (most recent call last)
Cell In[23], line 63
     60 jax.jacfwd(bar)(1.2)
     62 # this fails with the error below
---> 63 jax.jacrev(bar)(1.2)

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api.py:901, in jacrev.<locals>.jacfun(*args, **kwargs)
    899   y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    900 tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
--> 901 jac = vmap(pullback)(_std_basis(y))
    902 jac = jac[0] if isinstance(argnums, int) else jac
    903 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

    [... skipping hidden 8 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:707, in _scan_transpose(cts, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose, *args)
    705 ires, _ = split_list(consts, [num_ires])
    706 _, eres = split_list(xs, [sum(xs_lin)])
--> 707 assert not any(ad.is_undefined_primal(r) for r in ires)
    708 assert not any(ad.is_undefined_primal(r) for r in eres)
    710 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])

AssertionError: 

Any ideas?

patrick-kidger commented 4 months ago

Hmm. So 'undefined primals' actually refer to the tangents -- here, argsdot -- whose values aren't available when transposing. Somehow the JVP rule of the scan is saving such an undefined primal as one of its residual values (those values that the forward pass saves for the backward pass).

If you figure this out then I'd be curious to know the answer!