Open Azercoco opened 2 years ago
You are currently also measuring JIT compile time in your code. Maybe try something like this:
solve_jax_aot = jit(solve_jax).lower(A, b).compile()
...
for i in range(...):
    x, info = solve_jax_aot(A, b)
    ...
It's also helpful to set a maxiter in jax.scipy.sparse.linalg.gmres. Also, don't forget block_until_ready(); see https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code for more information.
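Putting the three suggestions together (ahead-of-time compilation, a bounded maxiter, and block_until_ready()), a minimal timing sketch could look like the following. The test matrix, its diagonal shift, and the run count are assumptions chosen for illustration and are not taken from the original benchmark.

```python
import time

import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg  # makes jax.scipy.sparse.linalg available
import numpy as np


def solve_jax(A, b):
    # maxiter bounds the number of restart cycles, so the run time stays bounded.
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b, tol=1e-5, maxiter=10)


rng = np.random.default_rng(0)
# Hypothetical test problem: a diagonally dominant 30x30 system that converges quickly.
A = jnp.asarray(rng.random((30, 30)) + 30.0 * np.eye(30), dtype=jnp.float32)
b = jnp.asarray(rng.random(30), dtype=jnp.float32)

# Compile ahead of time so that tracing and compilation are excluded from the timing.
solve_jax_aot = jax.jit(solve_jax).lower(A, b).compile()

n_runs = 10
start = time.perf_counter()
for _ in range(n_runs):
    x, info = solve_jax_aot(A, b)
    x.block_until_ready()  # wait for the asynchronous computation to finish
elapsed = (time.perf_counter() - start) / n_runs

residual = float(jnp.linalg.norm(A @ x - b))
print(f"{elapsed * 1e3:.3f} ms per solve, residual {residual:.2e}")
```

Calling block_until_ready() inside the loop matters because JAX dispatches work asynchronously; without it the loop may only measure dispatch time.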
For what it's worth, this is the result of your benchmark, slightly modified, and run on a Colab CPU runtime:
import jax
import jax.scipy.sparse.linalg
import numpy as np
import scipy.sparse.linalg
from jax import jit

def solve_jax(A, b):
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b, solve_method='batched', atol=1e-5)

solve_jax_jit = jit(solve_jax)

# np.random replaces the deprecated scipy.random alias used originally.
A = np.random.rand(30, 30)
b = np.random.rand(30)

# Warm-up call so that compilation is excluded from the timings below.
_ = solve_jax_jit(A, b)[0].block_until_ready()

%timeit scipy.sparse.linalg.gmres(A, b, atol=1e-5, restart=20)
# 10 loops, best of 5: 127 ms per loop
%timeit solve_jax_jit(A, b)[0].block_until_ready()
# 100 loops, best of 5: 11.3 ms per loop
Compilation time does not seem to be the cause of the issue. I just ran your code and got:
121 ms ± 6.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.15 s ± 407 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So at least the problem seems to come from my installation.
Can you say more about how you're running this? What versions of jax and jaxlib do you have installed? Are you on CPU or GPU? With or without X64 enabled?
I'm using
I installed the Windows build following these instructions (using the jaxlib-0.1.75+cuda111-cp39-none-win_amd64.whl wheel), so maybe the issue is within this build (https://github.com/cloudhan/jax-windows-builder).
This is a known issue with JAX's iterative methods on GPUs. XLA's loops on GPU sync back to the host CPU on each loop iteration, so they are slow if the function being solved is quick to evaluate. I actually made a benchmark for this issue that we passed off to the XLA GPU team.
The workaround is either to run this sort of computation on the CPU instead (e.g., by copying all arrays explicitly to the CPU with jax.device_put), or to try TPUs, for which XLA is able to run loops entirely on device.
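As a rough sketch of the CPU workaround, the arrays can be committed to the CPU device before calling the jitted solver, so that the computation runs where its inputs live. The problem size and the diagonally shifted test matrix below are illustrative assumptions.

```python
import jax
import jax.scipy.sparse.linalg  # makes jax.scipy.sparse.linalg available
import numpy as np

# The CPU backend is available even when a GPU is the default device.
cpu = jax.devices("cpu")[0]

rng = np.random.default_rng(0)
# Hypothetical well-conditioned test system.
A_host = (rng.random((30, 30)) + 30.0 * np.eye(30)).astype(np.float32)
b_host = rng.random(30).astype(np.float32)

# device_put with an explicit device commits the arrays to that device.
A = jax.device_put(A_host, cpu)
b = jax.device_put(b_host, cpu)


@jax.jit
def solve(A, b):
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b, tol=1e-5)


x, _ = solve(A, b)
x.block_until_ready()
print(x.devices())  # the result lives on the same CPU device as the inputs
```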
We could also consider writing/leveraging custom CUDA kernels or iterative solvers on GPUs, but that would be a major undertaking and likely would have some disadvantages (like losing pytree support).
I'm having the same problem on CPU:
import numpy as np
import scipy.sparse as scipy_sparse
from scipy.sparse.linalg import lgmres as scipy_gmres
import time

def create_sparse_matrix(n, density=0.01):
    A = scipy_sparse.random(n, n, density=density, format='csr')
    return A

def benchmark_scipy_gmres(A_scipy, b, tol=1e-5, maxiter=1000):
    start_time = time.time()
    x, info = scipy_gmres(A_scipy, b)
    end_time = time.time()
    # Check convergence
    residual = np.linalg.norm(A_scipy @ x - b)
    converged = info == 0  # GMRES returns 0 if converged
    return x, end_time - start_time, converged, float(residual), int(info)

def run_benchmark(n_values, num_runs=1):
    results = []
    for n in n_values:
        scipy_times = []
        convergence_rates = []
        residuals = []
        infos = []
        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)
            _, scipy_time, converged, residual, info = benchmark_scipy_gmres(A, b)
            scipy_times.append(scipy_time)
            convergence_rates.append(float(converged))
            residuals.append(residual)
            infos.append(info)
        results.append({
            'n': n,
            'scipy_avg_time': np.mean(scipy_times),
            'convergence_rate': np.mean(convergence_rates),
            'avg_residual': np.mean(residuals),
            'avg_info': np.mean(infos)
        })
        print(f"Completed benchmarks for n={n}")
    return results

if __name__ == "__main__":
    n_values = [100, 500, 1000, 2000]
    results = run_benchmark(n_values)
    print("\nBenchmark Results:")
    print("n\tSciPy GMRES (s)\tConv. Rate\tAvg. Residual\tAvg. Info")
    for result in results:
        print(f"{result['n']}\t{result['scipy_avg_time']:.4f}\t\t{result['convergence_rate']:.2f}\t\t{result['avg_residual']:.2e}\t\t{result['avg_info']:.2f}")
vs
import numpy as np
import scipy.sparse as scipy_sparse
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import gmres as jax_gmres
import time

def create_sparse_matrix(n, density=0.01):
    A = scipy_sparse.random(n, n, density=density, format='csr')
    return A

def scipy_to_jax_sparse(scipy_sparse_mat):
    return sparse.BCSR.from_scipy_sparse(scipy_sparse_mat)

def benchmark_jax_gmres(A_scipy, b, tol=1e-5, maxiter=1000):
    A_jax = scipy_to_jax_sparse(A_scipy)
    b_jax = jnp.array(b)

    @jax.jit
    def jax_gmres_solve(A, b):
        x, info = jax_gmres(A, b)
        return x, info

    # Compile the function
    _ = jax_gmres_solve(A_jax, b_jax)
    # Benchmark the compiled function
    start_time = time.time()
    x, info = jax_gmres_solve(A_jax, b_jax)
    x = x.block_until_ready()  # Ensure computation is complete
    end_time = time.time()
    # Check convergence
    residual = jnp.linalg.norm(A_jax @ x - b_jax)
    converged = info == 0  # GMRES returns 0 if converged
    return x, end_time - start_time, converged, float(residual), int(info)

def run_benchmark(n_values, num_runs=1):
    results = []
    for n in n_values:
        jax_times = []
        convergence_rates = []
        residuals = []
        infos = []
        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)
            _, jax_time, converged, residual, info = benchmark_jax_gmres(A, b)
            jax_times.append(jax_time)
            convergence_rates.append(float(converged))
            residuals.append(residual)
            infos.append(info)
        results.append({
            'n': n,
            'jax_avg_time': np.mean(jax_times),
            'convergence_rate': np.mean(convergence_rates),
            'avg_residual': np.mean(residuals),
            'avg_info': np.mean(infos)
        })
        print(f"Completed benchmarks for n={n}")
    return results

if __name__ == "__main__":
    n_values = [100, 500, 1000, 2000]
    results = run_benchmark(n_values)
    print("\nBenchmark Results:")
    print("n\tJAX GMRES (s)\tConv. Rate\tAvg. Residual\tAvg. Info")
    for result in results:
        print(f"{result['n']}\t{result['jax_avg_time']:.4f}\t\t{result['convergence_rate']:.2f}\t\t{result['avg_residual']:.2e}\t\t{result['avg_info']:.2f}")
@michelkluger, it's hard to say much here since you didn't include the results of your benchmarks, but some high-level thoughts: all of the sparse support in JAX is experimental, so your mileage may vary on the actual performance of these APIs. That said, there have been some changes in how XLA handles CPU iterations, so I'd be interested to know how the performance compares with and without the XLA_FLAGS=--xla_cpu_use_thunk_runtime=false environment variable set, to know if there is a specific regression here.
I created a new script to make the comparison fairer, using the same matrix and settings for both solvers:
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import gmres as scipy_gmres
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jax_sparse
from jax.scipy.sparse.linalg import gmres as jax_gmres
import time
import pandas as pd

# Constants
MATRIX_DENSITY = 0.1
TOLERANCE = 1e-5
MAX_ITERATIONS = 1000
NUM_RUNS = 1
N_VALUES = [100, 200, 500, 1000]
OUTPUT_FILE = "gmres_benchmark_results.csv"

def create_sparse_matrix(n, density=MATRIX_DENSITY):
    A = sp.random(n, n, density=density, format='csr')
    return A

def benchmark_scipy_gmres(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    start_time = time.time()
    try:
        x, info = scipy_gmres(A, b, atol=tol, maxiter=maxiter, restart=min(maxiter, A.shape[0]))
        end_time = time.time()
        residual = np.linalg.norm(A @ x - b)
        return end_time - start_time, info == 0, residual
    except Exception as e:
        print(f"SciPy GMRES error: {str(e)}")
        return np.nan, False, np.nan

@jax.jit
def jax_gmres_solve(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    x, info = jax_gmres(A, b, tol=tol, maxiter=maxiter, restart=min(maxiter, A.shape[1]))
    return x, info

def benchmark_jax_gmres(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    A_jax = jax_sparse.BCOO.from_scipy_sparse(A)
    b_jax = jnp.array(b)
    # Warm-up run
    _ = jax_gmres_solve(A_jax, b_jax)
    start_time = time.time()
    try:
        x, info = jax_gmres_solve(A_jax, b_jax)
        x = x.block_until_ready()
        end_time = time.time()
        residual = jnp.linalg.norm(A_jax @ x - b_jax)
        return end_time - start_time, info == 0, float(residual)
    except Exception as e:
        print(f"JAX GMRES error: {str(e)}")
        return np.nan, False, np.nan

def run_benchmark(n_values, num_runs=NUM_RUNS):
    results = []
    for n in n_values:
        print(f"Running benchmark for n={n}")
        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)
            scipy_time, scipy_conv, scipy_res = benchmark_scipy_gmres(A, b)
            jax_time, jax_conv, jax_res = benchmark_jax_gmres(A, b)
            results.append({
                'n': n,
                'scipy_time': scipy_time,
                'scipy_conv': scipy_conv,
                'scipy_res': scipy_res,
                'jax_time': jax_time,
                'jax_conv': jax_conv,
                'jax_res': jax_res,
                'speedup': scipy_time / jax_time if jax_time > 0 else np.nan
            })
    return pd.DataFrame(results)

if __name__ == "__main__":
    results_df = run_benchmark(N_VALUES)
    summary = results_df.groupby('n').mean()
    summary = summary[['scipy_time', 'jax_time', 'speedup', 'scipy_conv', 'jax_conv', 'scipy_res', 'jax_res']]
    summary.columns = ['SciPy Time (s)', 'JAX Time (s)', 'Speedup', 'SciPy Conv. Rate', 'JAX Conv. Rate', 'SciPy Residual', 'JAX Residual']
    print("\nBenchmark Results Summary:")
    pd.set_option('display.float_format', lambda x: f"{x:.4f}")
    print(summary)
    summary.to_csv(OUTPUT_FILE)
    print(f"\nResults have been saved to '{OUTPUT_FILE}'")
I added a new script for a fairer comparison. I am surprised, because from everything I have read, JAX was meant to be much faster. I am very new to this, though, so it is quite possible that I made some mistakes; I'd be happy to learn from them.
Thanks! I ran this script locally and I can confirm that the performance with XLA_FLAGS=--xla_cpu_use_thunk_runtime=false is comparable to SciPy, but without the flag there is a major performance regression.
Pinging @ezhulenev who has been looking into related performance issues.
@michelkluger — For your specific use case, I'd recommend setting that environment variable for now.
I ran with the flag:
import os
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
and checked that it was set with:
# get XLA_FLAGS env variable
def get_xla_flags():
    try:
        return os.environ["XLA_FLAGS"]
    except KeyError:
        return None

get_xla_flags()
'--xla_cpu_use_thunk_runtime=false'
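One thing worth double-checking when setting the flag from Python is ordering: XLA flags are generally picked up when the backend starts, so the safest place to set XLA_FLAGS is at the very top of the script, before JAX is imported. The sketch below uses a hypothetical helper, add_xla_flag, that appends the flag without overwriting flags that are already set; the flag name is the one quoted in this thread.

```python
import os

FLAG = "--xla_cpu_use_thunk_runtime=false"  # flag quoted in this thread


def add_xla_flag(flag):
    """Append `flag` to XLA_FLAGS unless it is already present (hypothetical helper)."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Call this at the very top of the script, before the first `import jax`.
xla_flags = add_xla_flag(FLAG)
print(xla_flags)

# import jax  # only import JAX after XLA_FLAGS has been set
```

Setting the variable in the shell before launching Python (XLA_FLAGS=... python script.py) avoids the ordering question entirely.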
What happens if you replace the sparse.BCSR matrix with a dense matrix? The jax.experimental.sparse code is just a reference implementation, and is quite slow even for simple operations like sparse matmul. I wouldn't be surprised if slow sparse matmul is driving the performance here.
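A minimal sketch of that dense-matrix experiment follows: the SciPy sparse matrix is converted with toarray() and passed to gmres as an ordinary array instead of a BCSR/BCOO object. The problem size, the diagonal shift (added only so the example converges quickly), and the solver settings are assumptions for illustration, not values from the benchmark scripts.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.scipy.sparse.linalg import gmres as jax_gmres

n = 200
# Hypothetical test matrix: random sparse entries plus a diagonal shift.
A_scipy = sp.random(n, n, density=0.1, format="csr", random_state=0) + 25.0 * sp.identity(n)
b_np = np.random.default_rng(0).random(n)

# Dense copies of the operands; no jax.experimental.sparse involved.
A_dense = jnp.asarray(A_scipy.toarray(), dtype=jnp.float32)
b = jnp.asarray(b_np, dtype=jnp.float32)


@jax.jit
def solve_dense(A, b):
    # gmres accepts a dense array directly as the linear operator.
    return jax_gmres(A, b, tol=1e-5, maxiter=100)


# Warm-up call so that compilation is excluded from the timing.
solve_dense(A_dense, b)[0].block_until_ready()

start = time.perf_counter()
x, _ = solve_dense(A_dense, b)
x.block_until_ready()
elapsed = time.perf_counter() - start

residual = float(jnp.linalg.norm(A_dense @ x - b))
print(f"dense solve: {elapsed * 1e3:.3f} ms, residual {residual:.2e}")
```

Timing this against the same solve with a BCSR or BCOO operand would show how much of the gap comes from the sparse matmul itself.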
Hello, I used the following script to compare the performance of the JAX gmres solver and the one from SciPy:
Here was the result:
Nearly a 20x difference! I tried to play around with the gmres parameters, but it did not change the SciPy/JAX time ratio.
I am using a GPU backend (could this explain the performance difference?) on Windows (my GPU is a GTX 1650) and I have CUDA 11.6 installed. The JAX version I'm using is 0.2.26.