WangHaiYang874 opened this issue 1 year ago (status: Open)
Curiously, it seems to work fine if you use cg or bicgstab for either a or b (or both), so the problem seems to be limited to gmres rather than the other iterative linear solvers. Under the hood, cg and bicgstab share infrastructure that gmres doesn't use.
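A minimal sketch of the nested-operator pattern using cg at both levels, which the comment above reports as working. The vector m, the matrix M, and the helpers inner_op and outer_op are hypothetical and not taken from the original script. M is chosen diagonal and positive so that both operators are symmetric positive definite, as cg requires; for non-symmetric operators, the same comment reports that bicgstab also works.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical SPD matrix for the inner operator.
m = jnp.array([2.0, 3.0, 4.0])
M = jnp.diag(m)

def inner_op(x):
    # Linear operator defined iteratively: returns (approximately) M^{-1} x.
    y, _ = cg(lambda v: M @ v, x)
    return y

def outer_op(x):
    # Second linear operator built on top of the first: (M^{-1} + I) x.
    return inner_op(x) + x

rhs = jnp.array([1.0, 1.0, 1.0])
sol, _ = cg(outer_op, rhs)

# For diagonal M, the exact solution is rhs * m / (1 + m).
expected = rhs * m / (1.0 + m)
print(sol, expected)
```

Because cg treats its operator as symmetric, neither solve needs the operator's transpose at call time, which is one plausible reason this path behaves differently from gmres.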
You may like to try Lineax, which is our new, more comprehensive solution for linear solvers. I've just checked and it appears to handle this case without issues.
Assuming this is still a bug, I did some digging, and I think the issue is caused when the solver calls linear_transpose at the end of the solve. Unlike every other time this function is called, when it generates the jaxpr (line 2282 in api.py), despite having, as far as I can tell, the same parameters, it produces a jaxpr with a non-zero number of equations (len(jaxpr.eqns) > 0) and with invars that are not in the constvars.
This causes a problem further down the line when backward_pass is called (line 181 in ad.py). That method builds a mapping of sorts between the constvars of the provided jaxpr and the constants passed into the function itself. It then iterates over every equation in jaxpr.eqns and, inside that loop, tries to read each invar from the mapping it built earlier. If an invar is not in the mapping, it gets an UndefinedPrimal by default. Later, when concrete_aval is called on the values, it is called on the UndefinedPrimal, which fails because an UndefinedPrimal apparently does not have the "__jax_array__" attribute, and the method errors out. I am not sure whether UndefinedPrimals are supposed to have that attribute (they probably aren't), and I have no idea why the jaxpr generation suddenly adds a bunch of extra invars and constvars.
Every other call I caught produced a jaxpr with 0 constvars, 1 invar (a), and 0 equations, which didn't cause issues in backward_pass: it iterated over the equations and stopped immediately because there weren't any. Again, as far as I can tell, the inputs to the jaxpr generation method are identical each time.
Alright, the above is not true. The error happens when the function a is called on the pvals, but why is a returning a jaxpr anyway? Shouldn't it just return a single value?
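For anyone retracing this kind of debugging on a small case, a minimal sketch: jax.linear_transpose builds the transpose of a linear function, and jax.make_jaxpr exposes the same eqns, invars and constvars fields discussed in the comments above. The function scale is a hypothetical stand-in for a, and make_jaxpr is used here only to look at those fields; linear_transpose does its own tracing internally, so its intermediate jaxpr is not guaranteed to match this one.

```python
import jax
import jax.numpy as jnp

def scale(x):
    # Hypothetical linear stand-in for the operator a.
    return 2.0 * x

x = jnp.arange(3.0)

# Transpose of the linear map x -> 2x; the returned function yields a tuple
# with one cotangent per primal input.
scale_T = jax.linear_transpose(scale, x)
(ct,) = scale_T(jnp.ones(3))

# Inspect the traced program: its equations, inputs and closed-over constants.
closed = jax.make_jaxpr(scale)(x)
jaxpr = closed.jaxpr
print(closed)
print(len(jaxpr.eqns), len(jaxpr.invars), len(jaxpr.constvars))
```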
I noticed this while working on something else, but it seems to be a more general problem when using gmres with fori_loop inside the linear function:
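import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg

def bfun(x):
    # Linear map x -> 3*x, built one element at a time with fori_loop.
    out = jnp.zeros_like(x)
    def body(i, out):
        out = out.at[i].set(3*x[i])
        return out
    return jax.lax.fori_loop(0, x.size, body, out)

jax.scipy.sparse.linalg.gmres(bfun, jnp.arange(5).astype(float))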
gives the following assertion error:
AssertionError Traceback (most recent call last)
Cell In[108], line 8
5 return out
6 return jax.lax.fori_loop(0, x.size, body, out)
----> 8 jax.scipy.sparse.linalg.gmres(bfun, jnp.arange(5).astype(float))
File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/scipy/sparse/linalg.py:704, in gmres(A, b, x0, tol, atol, restart, maxiter, M, solve_method)
702 def _solve(A, b):
703 return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
--> 704 x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)
706 failed = jnp.isnan(_norm(x))
707 info = jnp.where(failed, x=-1, y=0)
[... skipping hidden 12 frame]
File ~/miniconda3/envs/desc/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:718, in _scan_transpose(reduce_axes, cts, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, *args)
716 ires, _ = split_list(consts, [num_ires])
717 _, eres = split_list(xs, [sum(xs_lin)])
--> 718 assert not any(ad.is_undefined_primal(r) for r in ires)
719 assert not any(ad.is_undefined_primal(r) for r in eres)
721 carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
AssertionError:
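If the operator really is just an elementwise map, as bfun is here, one possible workaround (a sketch, not a fix for the underlying bug) is to write it without fori_loop. The traceback shows the failure inside _scan_transpose, reached from lax.custom_linear_solve, so avoiding the loop means gmres no longer has to transpose a scan. The name bfun_vectorized is hypothetical; it computes the same 3*x as bfun.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def bfun_vectorized(x):
    # Same linear map as bfun (x -> 3*x) written without fori_loop, so the
    # transpose requested through custom_linear_solve involves no scan.
    return 3 * x

b = jnp.arange(5).astype(float)
x, info = gmres(bfun_vectorized, b)
print(x, info)
```

This only helps when the loop can be expressed with vectorized operations; operators that genuinely need a loop would still hit the original error.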
Description
I am building an iterative solver. Naturally, some of my linear operators are defined iteratively using gmres. A simple example would be:
On top of this linear operator, I can define another linear operator:
However, JAX throws an error whenever I try to call b. Here is the full script:
And here is the error message:
What jax/jaxlib version are you using?
jax v0.4.12
Which accelerator(s) are you using?
cpu
Additional system info
No response
NVIDIA GPU info
No response