ToshiyukiBandai opened this issue 1 year ago
Looks like the overhead is from two things:
1. Error-checking on the Lineax output. By default Lineax has an extra check that the return doesn't have NaNs etc., i.e. that the solve was successful. This can be disabled by passing linear_solve(..., throw=False).
2. Pytree flattening/unflattening across JIT boundaries. matrix and vector are simpler PyTrees than operator and lx.LU() (see the sketch after this list).
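As a rough illustration of the second point, flattening the raw array and the operator side by side shows the extra pytree structure the operator carries. This is only a minimal sketch; the exact leaf count and treedef depend on the lineax/Equinox versions installed.

import jax
import jax.numpy as jnp
import lineax as lx

matrix = jnp.eye(10)

# A raw array flattens to a single leaf with trivial pytree structure.
leaves, treedef = jax.tree_util.tree_flatten(matrix)
print(len(leaves), treedef)

# A MatrixLinearOperator (and a solver object such as lx.LU()) is an Equinox
# module, so it carries additional pytree structure that must be flattened and
# unflattened every time it crosses a JIT boundary.
operator = lx.MatrixLinearOperator(matrix)
op_leaves, op_treedef = jax.tree_util.tree_flatten(operator)
print(len(op_leaves), op_treedef)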
With this benchmark I obtain identical performance:
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))

@jax.jit
def solve_lineax(matrix, vector):
    # Build the operator inside JIT and disable the NaN/success check.
    operator = lx.MatrixLinearOperator(matrix)
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value

@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)

# block_until_ready so we time the actual computation, not the async dispatch.
time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))
Hi Patrick,
I got the same results too. Thank you!
Hi, thank you for creating these awesome JAX libraries. I started using lineax recently and compared it with the linear solver in JAX. My benchmark took 931 µs with lineax versus 171 µs with jnp.linalg.solve. Is there anything wrong with my implementation, or should I just stick to jnp.linalg.solve? Is there no way to use the _gesv Fortran routine through lineax?
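The original benchmark is not reproduced in this thread. As a hedged sketch of the kind of comparison described, the version below builds the operator and the lx.LU() solver outside the jitted function and leaves throw at its default of True, which are exactly the two overhead sources discussed above; the OP's actual code may have differed.

import timeit
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))

# Operator and solver are constructed outside the jitted function, so both are
# flattened/unflattened at every call, and the default throw=True keeps the
# NaN/success check on the result.
operator = lx.MatrixLinearOperator(matrix)
solver = lx.LU()

@jax.jit
def solve_lineax(operator, vector, solver):
    return lx.linear_solve(operator, vector, solver=solver).value

@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)

time_lineax = lambda: jax.block_until_ready(solve_lineax(operator, vector, solver))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))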