vboussange opened 4 days ago
So (a) I think you've made a few mistakes in the benchmarking, and (b) most Lineax/Optimistix/Diffrax routines finish with an option to throw a runtime error if things have gone wrong, and this adds a measurable amount of overhead on microbenchmarks such as this. It can be disabled with `throw=False`.
So adjusting things a little, I get exactly comparable results between the two approaches.
```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx


@jax.jit
def linalg_solve(A, B):
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error


@jax.jit
def lineax_solve(A, B):
    operator = lx.MatrixLinearOperator(A)

    def solve_single(b):
        x = lx.linear_solve(operator, b, throw=False).value
        return x

    x = jax.vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error


def benchmark(method, func):
    times = timeit.repeat(func, number=1, repeat=10)
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} min time: {min(times)}\n")


N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

# Warm up the JIT caches before timing.
linalg_solve(A, B)
lineax_solve(A, B)

benchmark("linalg.solve", lambda: jax.block_until_ready(linalg_solve(A, B)))
benchmark("lineax", lambda: jax.block_until_ready(lineax_solve(A, B)))

# linalg.solve solve error: 7.080040e-06
# linalg.solve min time: 4.237500252202153e-05
#
# lineax solve error: 7.080040e-06
# lineax min time: 3.9375037886202335e-05
```
Notable changes here:

- `throw=False`, to disable Lineax's checking for success (and just silently return NaNs if things go wrong).
- `jax.block_until_ready`, so that JAX's asynchronous dispatch doesn't distort the measurements.
- Taking the `min` with `repeat=10`, rather than the mean, over the evaluation times. As benchmarking noise is one-sided, this is usually the correct aggregation method for microbenchmarks.

FWIW I've also trimmed out the use of `state` and the explicit `lineax.LU()` solver, as the former is done already inside the solve and the latter is the default.
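The min-versus-mean point generalises beyond this benchmark: timing noise (scheduler preemption, cache misses, GC pauses) only ever makes a run slower, never faster, so the minimum over repeats is the tightest estimate of the true cost. A self-contained illustration using only the standard library:

```python
import timeit


def work():
    # Some fixed workload to time.
    return sum(range(100_000))


# number=1 times single evaluations; repeat=10 collects ten samples.
times = timeit.repeat(work, number=1, repeat=10)

# Every sample is (true cost + non-negative noise), so the minimum
# is the best estimate of the true cost, while the mean is inflated
# by whatever noise the slower runs picked up.
best = min(times)
mean = sum(times) / len(times)
```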
Excellent, thanks for the details!
> FWIW I've also trimmed out the use of `state` and the explicit `lineax.LU()` solver, as the former is done already inside the solve and the latter is the default.
I am surprised that the `vmap` is not triggering multiple internal `init` calls?
> I am surprised that the `vmap` is not triggering multiple internal `init` calls?
`init` is called only on the non-vmap'd input `A`, so it won't be vmap'd.
Hey there,

Some native JAX solvers such as `jnp.linalg.solve` and `jax.scipy.sparse.linalg.gmres` nicely support batch mode, where the right-hand side of the system $A X = B$ is an $n \times m$ matrix. What is the best approach to efficiently reproduce this behaviour with `lineax`?

I made a benchmark using `vmap` and `lineax`, but this approach is 4x slower: