Open Dong555 opened 12 months ago
Thanks for the report! This looks to be a bug in JAX itself: https://github.com/google/jax/issues/18244.
Sparse matrices still only have experimental support, so it looks like all the kinks haven't been worked out yet.
Since you're not batching over the operator, it looks like you can work around this by capturing the operator via closure, rather than passing it as an unbatched argument:
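import jax
import lineax as lx
from jax import Array
from jax.typing import ArrayLike

def _solve_beta(A: lx.AbstractLinearOperator, rhs: ArrayLike) -> Array:
    solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
    # A is captured by the lambda's closure, so vmap only batches the columns of rhs.
    out = jax.vmap(lambda b: lx.linear_solve(A, b, solver), in_axes=1)(rhs)
    return out.value.T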
Note that you do have to use jax.vmap rather than eqx.filter_vmap here; the latter still ends up hitting a variant of the above bug in jax.experimental.sparse.BCOO.
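To make the closure-capture pattern concrete without depending on the sparse code path, here is a minimal sketch that uses plain JAX only. It assumes a small dense matrix and swaps in jnp.linalg.solve for lx.linear_solve, so it illustrates the vmap structure rather than reproducing the bug or the lineax API; the name solve_columns and the example values are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative dense system (an assumption; the thread uses a sparse lineax operator).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
# Shape (2, 3): three right-hand-side columns.
rhs = jnp.array([[1.0, 2.0, 0.0], [2.0, 1.0, 1.0]])


def solve_columns(A, rhs):
    # A is captured by the lambda's closure, so vmap never sees it as an
    # argument and only batches over the columns b of rhs (in_axes=1).
    per_column = jax.vmap(lambda b: jnp.linalg.solve(A, b), in_axes=1)(rhs)
    # vmap stacks the per-column solutions along axis 0; transpose so that
    # column j of the result solves A @ x = rhs[:, j].
    return per_column.T


X = solve_columns(A, rhs)
```

Because the operator is not an argument of the mapped function, vmap never has to match an in_axes entry against the operator's pytree structure, which is the step that appears to go wrong in the sparse case.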
I'm encountering a ValueError when using vmap to map the lx.linear_solve operation across multiple columns of a right-hand side (RHS) matrix with a SparseMatrixOperator.
I expect the _solve_beta function, when passed a SparseMatrixOperator and a multi-column RHS, to solve the linear system for each column of the RHS, just as it does when a MatrixLinearOperator is used. While the linear solve works fine for individual columns of y with a SparseMatrixOperator, solving for multiple columns using vmap fails with the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ... (trimmed for brevity)
Here is a minimal working example specific to this problem: