jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

nested gmres does not work #16441

Open WangHaiYang874 opened 1 year ago

WangHaiYang874 commented 1 year ago

Description

I am building an iterative solver. Naturally, some of my linear operators are themselves defined iteratively using gmres. A simple example would be:

def a(x):
    # this function is essentially the linear operator x -> 2*x
    return gmres(lambda i: i, 2 * x)[0]

On top of this linear operator, I can define another linear operator:

def b(x):
    # this function is essentially the inverse of a
    return gmres(a, x)[0]

However, JAX throws an error whenever I try to call b.

Here is the full script:

import jax
from jax import numpy as np
from functools import partial

gmres = partial(jax.scipy.sparse.linalg.gmres, tol=1e-3, restart=10, solve_method='batched')

def a(x):
    # this function is essentially the linear operator x-> 2*x
    return gmres(lambda i: i, 2 * x)[0]

def b(x):
    # this function is essentially the inverse of a
    # but it does not work...
    return gmres(a, x)[0]

if __name__ == '__main__':
    t = np.ones((10,))
    b(t)

And here is the error message:

TypeError: Value UndefinedPrimal(ShapedArray(float32[])) with type <class 'jax._src.interpreters.ad.UndefinedPrimal'> is not a valid JAX type

What jax/jaxlib version are you using?

jax v0.4.12

Which accelerator(s) are you using?

cpu


f0uriest commented 1 year ago

Curiously, it seems to work fine if you use cg or bicgstab for either a or b (or both), so the problem appears to be limited to gmres rather than affecting the other iterative linear solvers. Under the hood, cg and bicgstab share similar infrastructure that gmres doesn't use.

patrick-kidger commented 1 year ago

You may like to try Lineax, which is our new more-comprehensive solution for linear solvers. I've just checked and it appears to handle this case without issues.

ThatGuyDavid09 commented 1 year ago

Assuming this is still a bug, I did some digging, and the issue is probably triggered when the solver calls linear_transpose at the end of the solve. Unlike every other call to this function, the jaxpr it generates here (line 2282 in api.py) has a non-zero number of equations (len(jaxpr.eqns) > 0) and invars that are not in the constvars, even though, as far as I can tell, it is called with the same parameters.

This causes a problem further down the line when backward_pass is called (line 181 in ad.py). That function builds a mapping of sorts between the constvars of the provided jaxpr and some other constants passed into the function itself. It then iterates over every equation in jaxpr.eqns and, as part of that loop, tries to read each equation's invars from the mapping it built earlier. If an invar is not defined in the mapping, it gets an UndefinedPrimal by default. Later, concrete_aval is called on those values, and it fails on the UndefinedPrimal because an UndefinedPrimal apparently does not have the "__jax_array__" attribute, which raises the error.

I am not sure whether UndefinedPrimals are supposed to have that attribute (they probably aren't), and I have no idea why the jaxpr generation suddenly includes a bunch of extra invars and constvars. Every other call I caught produced a jaxpr with 0 constvars, 1 invar (a), and 0 equations, which caused no issues in backward_pass: the loop over the equations stopped immediately because there weren't any. Once again, as far as I can tell, the inputs to the jaxpr generation are identical each time.
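To look at the quantities mentioned above (constvars, invars, eqns), jax.make_jaxpr can be applied to a small linear function, and jax.linear_transpose can be called on it directly. The sketch below uses two stand-in functions (identity and double, both hypothetical names) rather than the gmres internals, so it only illustrates what an empty versus a non-empty jaxpr looks like, not the failing call itself.

```python
import jax
import jax.numpy as jnp

def identity(v):
    # Stand-in for the `lambda i: i` operator in the report.
    return v

def double(v):
    # Stand-in for a linear operator that does some actual work.
    return 2.0 * v

x = jnp.ones((3,))

# Trace each function and inspect the resulting jaxpr.
id_jaxpr = jax.make_jaxpr(identity)(x).jaxpr
dbl_jaxpr = jax.make_jaxpr(double)(x).jaxpr
print(len(id_jaxpr.constvars), len(id_jaxpr.invars), len(id_jaxpr.eqns))
print(len(dbl_jaxpr.constvars), len(dbl_jaxpr.invars), len(dbl_jaxpr.eqns))

# Transpose the linear function; the result maps an output cotangent
# back to a tuple of input cotangents.
transposed = jax.linear_transpose(double, x)
(ct,) = transposed(jnp.array([1.0, 2.0, 3.0]))
print(ct)
```

The identity function traces to a jaxpr with no equations, matching the "0 constvars, 1 invar, 0 equations" case described above, while any function that performs an operation on its input produces at least one equation for the backward pass to walk over.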

ThatGuyDavid09 commented 1 year ago

Alright, the above is not true. The error happens when the function a is called on the pvals, but why is a returning a jaxpr anyway? Shouldn't it just return a single value?

f0uriest commented 1 year ago

I noticed this while working on something else, but it seems to be a more general problem with gmres whenever the linear function uses fori_loop:

def bfun(x):
    out = jnp.zeros_like(x)
    def body(i, out):
        out = out.at[i].set(3*x[i])
        return out
    return jax.lax.fori_loop(0, x.size, body, out)

jax.scipy.sparse.linalg.gmres(bfun, jnp.arange(5).astype(float))

gives the following assertion error:


AssertionError                            Traceback (most recent call last)
Cell In[108], line 8
      5         return out
      6     return jax.lax.fori_loop(0, x.size, body, out)
----> 8 jax.scipy.sparse.linalg.gmres(bfun, jnp.arange(5).astype(float))

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/scipy/sparse/linalg.py:704, in gmres(A, b, x0, tol, atol, restart, maxiter, M, solve_method)
    702 def _solve(A, b):
    703   return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
--> 704 x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
    706 failed = jnp.isnan(_norm(x))
    707 info = jnp.where(failed, x=-1, y=0)

    [... skipping hidden 12 frame]

File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:718, in _scan_transpose(reduce_axes, cts, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, *args)
    716 ires, _ = split_list(consts, [num_ires])
    717 _, eres = split_list(xs, [sum(xs_lin)])
--> 718 assert not any(ad.is_undefined_primal(r) for r in ires)
    719 assert not any(ad.is_undefined_primal(r) for r in eres)
    721 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])

AssertionError:
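For comparison, here is a sketch of the same operator (x -> 3*x) written as a single vectorized expression instead of a fori_loop. It is not a fix for the assertion above; it only shows a form of the operator that contains no loop for gmres to transpose. The name bfun_vectorized is introduced here for illustration.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def bfun_vectorized(x):
    # Same linear map as bfun above, expressed without fori_loop.
    return 3 * x

rhs = jnp.arange(5, dtype=jnp.float32)

# gmres returns (solution, info); the solution should approximate rhs / 3.
x, info = gmres(bfun_vectorized, rhs)
print(x)
```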