patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Is lineax slower than the linear solver in JAX? #52

Open ToshiyukiBandai opened 1 year ago

ToshiyukiBandai commented 1 year ago

Hi, thank you for creating the awesome libraries in JAX. I recently started using lineax and compared it with the linear solver in JAX. The code below took 931 µs with lineax and 171 µs with jnp.linalg.solve. Is there anything wrong with my implementation, or should I just stick to jnp.linalg.solve? Is there no way to use the _gesv Fortran routine through lineax?

from jax import random
import jax.numpy as jnp
import lineax as lx

matrix_key, vector_key = random.split(random.PRNGKey(0))
matrix = random.normal(matrix_key, (10, 10))
vector = random.normal(vector_key, (10,))

operator = lx.MatrixLinearOperator(matrix)
solution = lx.linear_solve(operator, vector)

%timeit lx.linear_solve(operator, vector, solver=lx.LU())

%timeit jnp.linalg.solve(matrix, vector)

patrick-kidger commented 1 year ago

Looks like the overhead is from two things:

  1. Error-checking on the Lineax output. By default Lineax performs an extra check that the returned solution doesn't contain NaNs etc., i.e. that the solve was successful. This can be disabled by passing linear_solve(..., throw=False).

  2. PyTree flattening/unflattening across JIT boundaries. matrix and vector are simpler PyTrees than operator and lx.LU().
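
To illustrate point 1: once the check is turned off with throw=False, a caller who still wants a safety net could test the output for non-finite values themselves. Below is a minimal sketch of that idea in plain JAX. The helper name solve_with_check and the 2x2 example system are hypothetical, chosen only for illustration, and this is not a description of how Lineax implements its own check.

```python
import jax
import jax.numpy as jnp


@jax.jit
def solve_with_check(matrix, vector):
    # Hypothetical helper, not part of JAX or Lineax.
    # Solve the system, then flag whether every entry of the result is finite.
    x = jnp.linalg.solve(matrix, vector)
    ok = jnp.all(jnp.isfinite(x))
    return x, ok


# Small well-conditioned system, made up for this example.
matrix = jnp.array([[2.0, 1.0], [1.0, 3.0]])
vector = jnp.array([3.0, 5.0])

x, ok = solve_with_check(matrix, vector)
print("solution:", x)
print("all entries finite:", bool(ok))
```

Because the flag is computed inside the jitted function and returned as an array, the caller decides outside of JIT what to do with it (raise, log, or ignore), rather than paying for an error check on every call.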

With this benchmark I obtain identical performance:

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

matrix_key, vector_key = jr.split(jr.PRNGKey(0))
matrix = jr.normal(matrix_key, (10, 10))
vector = jr.normal(vector_key, (10,))

@jax.jit
def solve_lineax(matrix, vector):
    operator = lx.MatrixLinearOperator(matrix)
    sol = lx.linear_solve(operator, vector, throw=False)
    return sol.value

@jax.jit
def solve_jax(matrix, vector):
    return jnp.linalg.solve(matrix, vector)

time_lineax = lambda: jax.block_until_ready(solve_lineax(matrix, vector))
time_jax = lambda: jax.block_until_ready(solve_jax(matrix, vector))

print(min(timeit.repeat(time_jax, number=1, repeat=10)))
print(min(timeit.repeat(time_lineax, number=1, repeat=10)))

ToshiyukiBandai commented 1 year ago

Hi Patrick,

I got the same results too. Thank you!