jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Reverse-mode AD failing with GMRES in custom JVP #5309

Open mannsean opened 3 years ago

mannsean commented 3 years ago

@romanodev and I are working on topology optimization using JAX's newly released GMRES function. To avoid differentiating through the iterative solver, we are using @jax.custom_jvp with implicit functions. However, reverse-mode AD fails, while both forward- and reverse-mode work if we replace GMRES with something like np.linalg.solve. Here is a minimal example of the issue (to keep things tidy, the solve is just the identity function); the full example is in the Colab (https://colab.research.google.com/drive/1fIZgoB2zdMpErqi-q54k9CRsyUfIFbfp?usp=sharing). Any idea what is happening here? Thanks!

from jax import numpy as np
import jax

def f(x, p):
    """
    Minimal problem that can be solved by
    setting x = p for some given p
    """
    return x - p

key = jax.random.PRNGKey(0)
p_test = jax.random.normal(key, (10,))

@jax.custom_jvp
def solve_gmres(p):
    """
    Trivial solution for x
    """
    return p

@solve_gmres.defjvp
def solve_gmres_jvp(primals, tangents):
    """
    Custom JVP using GMRES
    """
    p, = primals
    dp, = tangents
    x = solve_gmres(p)
    f_x, f_p = jax.jacfwd(f, argnums=(0, 1))(x, p)
    dx, _ = jax.scipy.sparse.linalg.gmres(f_x, -f_p @ dp)  # Only difference from the np.linalg.solve version
    return x, dx

jax.jacfwd(solve_gmres)(p_test) # works
jax.jacrev(solve_gmres)(p_test) # fails with error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-5ecdbc205337> in <module>()
----> 1 jax.jacrev(solve_gmres)(p_test)
20 frames
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in abstractify(x)
    174     aval_fn = pytype_aval_mappings.get(typ)
    175     if aval_fn: return aval_fn(x)
--> 176   raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
    177 
    178 def _make_abstract_python_scalar(typ, _):
TypeError: Argument 'UndefinedPrimal(ShapedArray(float32[]))' of type '<class 'jax.interpreters.ad.UndefinedPrimal'>' is not a valid JAX type
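
For comparison, the variant described above as working in both modes might look like the following. This is a minimal sketch, assuming the only change is swapping the GMRES call for a dense `jnp.linalg.solve`; the names `solve_dense`, `solve_dense_jvp`, `J_fwd` and `J_rev` are illustrative and do not come from the Colab.

```python
import jax
import jax.numpy as jnp

def f(x, p):
    """Residual whose root is x = p."""
    return x - p

@jax.custom_jvp
def solve_dense(p):
    """Trivial primal solve (hypothetical name)."""
    return p

@solve_dense.defjvp
def solve_dense_jvp(primals, tangents):
    p, = primals
    dp, = tangents
    x = solve_dense(p)
    # Jacobians of the residual with respect to x and p.
    f_x, f_p = jax.jacfwd(f, argnums=(0, 1))(x, p)
    # Implicit function theorem: f_x dx + f_p dp = 0.
    # Assumption: a dense solve stands in for GMRES here.
    dx = jnp.linalg.solve(f_x, -f_p @ dp)
    return x, dx

p_test = jax.random.normal(jax.random.PRNGKey(0), (10,))
J_fwd = jax.jacfwd(solve_dense)(p_test)  # forward mode
J_rev = jax.jacrev(solve_dense)(p_test)  # reverse mode
```

Because the tangent rule is linear in `dp` and is built only from operations JAX knows how to transpose, both calls should return the same Jacobian for this toy problem.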

shoyer commented 3 years ago

It's great to see this use case! Topology optimization is very near to my heart and is exactly the sort of thing we wanted to enable with GMRES.

It's not immediately obvious to me what is going on, but hopefully we can figure it out. It does look like a bug -- I think this is supposed to work (or at least give a better error message).

I am curious why you need a custom JVP here. At first glance, what you're doing looks quite similar to implementing the implicit function theorem, which is exactly what lax.custom_root does. Have you tried using custom_root? I guess this could save you a few lines of code, but unfortunately it is also broken with GMRES, with a very similar error message (https://github.com/google/jax/pull/5321).
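
As a rough illustration of the custom_root route for the toy problem in this issue, here is a minimal sketch. It assumes a dense `jnp.linalg.solve` for `tangent_solve` rather than GMRES (which is reported above as broken at this point); `solve_with_custom_root` and the inner helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_with_custom_root(p):
    """Find x with x - p = 0 via lax.custom_root (hypothetical wrapper)."""

    def f(x):
        # Residual; p is closed over and is what we differentiate with respect to.
        return x - p

    def solve(f, x0):
        # f is affine with unit slope, so one Newton step from x0 is exact.
        return x0 - f(x0)

    def tangent_solve(g, y):
        # g is the linearization of f at the solution.
        # Materialize its matrix and solve g(x) = y densely.
        return jnp.linalg.solve(jax.jacobian(g)(y), y)

    return lax.custom_root(f, jnp.zeros_like(p), solve, tangent_solve)

p_test = jax.random.normal(jax.random.PRNGKey(0), (10,))
x_root = solve_with_custom_root(p_test)
J_root_fwd = jax.jacfwd(solve_with_custom_root)(p_test)
J_root_rev = jax.jacrev(solve_with_custom_root)(p_test)
```

The derivative logic never sees the internals of `solve`; only `tangent_solve` has to be differentiable and transposable, which is why a dense solve is used for it in this sketch.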

mannsean commented 3 years ago

Great! Thanks for the suggestion -- I wrote the custom JVP with the IFT in mind, but really wasn't aware of lax.custom_root. It feels like custom_root is a little less flexible, and hence messier to fit into our use case: keyword arguments aren't allowed, and from some quick tests it seems I can no longer use native Python loops and logic in the solve and tangent_solve functions supplied to custom_root (using lax.while_loop seems to slow the code down). Another thing is the linearization of f -- since we implemented our own (sparse) Jacobian for f, computing g (the linearized f that custom_root passes to tangent_solve) is unnecessary. Would you say using custom_root brings any speedup relative to custom_jvp?

Hope we can fix the bug, and thank you for developing JAX!

shoyer commented 3 years ago

custom_root is just a particular implementation of the IFT, also written using custom_jvp. There is nothing wrong with rolling your own version, and indeed that might be desirable in some cases, e.g., if you want to use Python control flow.
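
One way to roll your own for reverse mode only is to write the adjoint form of the IFT with jax.custom_vjp, so that GMRES is called on concrete cotangents and never has to be transposed. The sketch below is an assumption on my part, not something proposed or confirmed in this thread: it solves f_x^T lam = x_bar with GMRES and returns p_bar = -f_p^T lam. The names `solve_adjoint`, `solve_adjoint_fwd` and `solve_adjoint_bwd` are hypothetical, and a function defined this way no longer supports forward mode.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def f(x, p):
    """Residual whose root is x = p."""
    return x - p

@jax.custom_vjp
def solve_adjoint(p):
    """Trivial primal solve (hypothetical name)."""
    return p

def solve_adjoint_fwd(p):
    x = p  # trivial primal solve; a real solver call would go here
    return x, (x, p)  # save what the backward pass needs

def solve_adjoint_bwd(residuals, x_bar):
    x, p = residuals
    f_x, f_p = jax.jacfwd(f, argnums=(0, 1))(x, p)
    # Adjoint system: f_x^T lam = x_bar, solved iteratively.
    lam, _ = gmres(f_x.T, x_bar)
    # Cotangent for p: p_bar = -f_p^T lam.
    return (-f_p.T @ lam,)

solve_adjoint.defvjp(solve_adjoint_fwd, solve_adjoint_bwd)

p_test = jax.random.normal(jax.random.PRNGKey(0), (10,))
J_adj = jax.jacrev(solve_adjoint)(p_test)
```

For the identity problem here the result should be close to the identity matrix, up to the GMRES tolerance.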

mattjj commented 3 years ago

@froystig I co-assigned you just now b/c this is relevant to stuff we've been talking about.