jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`custom_root` with integer aux output broken in 0.4.34 #24295

Open f0uriest opened 2 weeks ago

f0uriest commented 2 weeks ago

Description

I have a basic root finder like this:

import jax
import jax.numpy as jnp

def root(
    fun,
    x0,
    jac=None,
    args=(),
    tol=1e-6,
    maxiter=20,
):
    """Find x where fun(x, *args) == 0."""

    # Residual and Jacobian wrappers that always return 1D/2D arrays.
    jac2 = lambda x: jnp.atleast_2d(jax.jacfwd(fun)(x, *args))
    res = lambda x: jnp.atleast_1d(fun(x, *args)).flatten()

    # Newton's method: iterate until maxiter is reached or ||f||^2 <= tol^2.
    def solve(resfun, guess):
        def condfun(state):
            xk1, fk1, k1 = state
            return (k1 < maxiter) & (jnp.dot(fk1, fk1) > tol**2)

        def bodyfun(state):
            xk1, fk1, k1 = state
            J = jac2(xk1)
            d = jnp.linalg.solve(J, fk1)
            xk2 = xk1 - d
            fk2 = resfun(xk2)
            return xk2, fk2, k1 + 1

        state = (
            jnp.atleast_1d(jnp.asarray(guess)), # x
            jnp.atleast_1d(resfun(guess)), # residual
            0, # number of iterations
        )
        state = jax.lax.while_loop(condfun, bodyfun, state)
        return state[0], state[1:]

    def tangent_solve(g, y):
        # g is linear here, so materialize it as a matrix with jacfwd
        # and solve g(x) = y directly.
        A = jnp.atleast_2d(jax.jacfwd(g)(y))
        return jnp.linalg.solve(A, jnp.atleast_1d(y))

    x, aux = jax.lax.custom_root(
        res, x0, solve, tangent_solve, has_aux=True
    )
    return x, aux

which returns the root along with the value of fun at the root and the number of iterations taken. Previously this worked fine with has_aux=True for custom_root. However, v0.4.34 seems to have changed how tangents of non-differentiable values are propagated (#24262).
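For context, JAX represents the tangent of an integer (non-differentiable) output with the special float0 dtype rather than an integer dtype, which is the convention the error below complains about. A minimal sketch, independent of custom_root:

import jax
import jax.numpy as jnp

# The second output is integer-valued; its tangent comes back as a
# float0 array, not as an integer.
_, tangents_out = jax.jvp(
    lambda x: (2.0 * x, jnp.array(3)),
    (1.0,), (1.0,)
)
print(tangents_out[1].dtype)  # float0 — the tangent dtype JAX expects for ints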

Now running the following

def fun(x, a):
    return a*x - 1

def find_root_fun(a):
    x0 = jnp.array([0.,])
    xk, aux = root(fun, x0, args=(a,))
    return xk, aux

a = jnp.array(2.0)  # example value for the parameter a
jax.jacfwd(find_root_fun, has_aux=True)(a)

gives the following:

--------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[26], line 59
     56     xk, aux = root(fun, x0, args=(a,))
     57     return xk, aux
---> 59 jax.jacfwd(find_root_fun, has_aux=True)(a)

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/api.py:584, in jacfwd.<locals>.jacfun(*args, **kwargs)
    582 else:
    583   pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=True)
--> 584   y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args))
    585 tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    586 example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args

    [... skipping hidden 5 frame]

Cell In[26], line 56, in find_root_fun(a)
     54 def find_root_fun(a):
     55     x0 = jnp.array([0.,])
---> 56     xk, aux = root(fun, x0, args=(a,))
     57     return xk, aux

Cell In[26], line 43, in root(fun, x0, jac, args, tol, maxiter)
     40     A = jnp.atleast_2d(jax.jacfwd(g)(y))
     41     return jnp.linalg.solve(A, jnp.atleast_1d(y))
---> 43 x, (f, niter) = jax.lax.custom_root(
     44     res, x0, solve, tangent_solve, has_aux=True
     45 )
     46 return x,  (jnp.sum(jnp.abs(f)), niter)

    [... skipping hidden 7 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args)
    344     msg = ("Custom JVP rule must produce primal and tangent outputs with "
    345            "corresponding shapes and dtypes, but got:\n{}")
    346     disagreements = (
    347         f"  primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}"
    348         for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out)
    349         if av_et != av_t)
--> 351     raise TypeError(msg.format('\n'.join(disagreements)))
    352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got:
  primal int64[] with tangent int64[], expecting tangent ShapedArray(float0[])

I get the same error if I drop the aux output in find_root_fun and leave out has_aux when calling jacfwd. The only way I've found to avoid the error is to remove the aux output from the innermost solve and set has_aux=False on custom_root, as sketched below.
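For reference, a sketch of that workaround inside root (solve_noaux is a hypothetical name); it gives up the residual and iteration count entirely:

    ...
    def solve_noaux(resfun, guess):
        # same Newton loop, but discard the residual and iteration count
        x, _ = solve(resfun, guess)
        return x

    x = jax.lax.custom_root(res, x0, solve_noaux, tangent_solve)  # has_aux=False
    return x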

Is this expected? I assumed that supporting integer-valued aux output was kind of the point of the has_aux option?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.24.4
python: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0]
jax.devices (8 total, 8 local): [CpuDevice(id=0) CpuDevice(id=1) ... CpuDevice(id=6) CpuDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='Discovery', release='5.15.0-119-generic', version='#129~20.04.1-Ubuntu SMP Wed Aug 7 13:07:13 UTC 2024', machine='x86_64')

YigitElma commented 1 week ago

I think we can simply get rid of the error by changing the dtype of the iteration count to float, like:

        ...
        state = (
            jnp.atleast_1d(jnp.asarray(guess)), # x
            jnp.atleast_1d(resfun(guess)), # residual
            0.0, # number of iterations
        )
        ...

As long as we don't use the derivative of the number of iterations later in the code, I believe this shouldn't change the differentiation of root.
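Assuming that change works as described, the original repro should then run, e.g. (a = 2.0 is a hypothetical test value):

a = jnp.array(2.0)  # hypothetical test value
jac, (f, niter) = jax.jacfwd(find_root_fun, has_aux=True)(a)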

That said, this is probably not how you want to implement it. A more proper way could be to write a custom_jvp for root and set the tangent of niter to a symbolic zero, but this is more cumbersome.
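For illustration, a rough sketch of that idea on a toy problem (toy_root and its closed-form solve are hypothetical stand-ins for the real Newton solver, and the zero tangent is written out as a float0 array rather than going through the SymbolicZero machinery):

import numpy as np
import jax
import jax.numpy as jnp

@jax.custom_jvp
def toy_root(a):
    x = 1.0 / a           # closed-form stand-in for the Newton solve
    niter = jnp.array(3)  # stand-in for the integer iteration count
    return x, niter

@toy_root.defjvp
def toy_root_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    x, niter = toy_root(a)
    dx = -da / a**2  # tangent of the root x = 1/a
    # Integer outputs must carry a float0 tangent — exactly the dtype
    # the error above says it expects.
    dniter = np.zeros(niter.shape, dtype=jax.dtypes.float0)
    return (x, niter), (dx, dniter)

jax.jvp(toy_root, (2.0,), (1.0,))  # primals (0.5, 3), tangents (-0.25, float0)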