jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX GMRES on GPU largely slower than its SciPy counterpart #9259

Open Azercoco opened 2 years ago

Azercoco commented 2 years ago

Hello, I used the following script to compare the performance of the JAX GMRES solver with the one from SciPy:

from time import time

import jax
import jax.scipy.sparse.linalg
import numpy as np
import scipy.sparse.linalg
from jax import jit

def solve_jax(A, b):
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b, solve_method='batched', atol=1e-5)

solve_jax_jit = jit(solve_jax)

A = np.random.rand(30, 30)
b = np.random.rand(30)

t1 = time()
for i in range(20):
    x = scipy.sparse.linalg.gmres(A, b, atol=1e-5, restart=20)
    #print(i, x[0][0])
t2 = time()
for i in range(20):
    x_jax = solve_jax_jit(A, b)
    #print(i, x_jax[0][0])
t3 = time()

print(f"{(t2 - t1)/20:.2e} vs {(t3 - t2)/20:.2e}")

Here was the result:

1.09e-01 vs 2.04e+00

Nearly a 20x difference! I tried playing around with the GMRES parameters, but it did not change the SciPy/JAX time ratio.

I am using a GPU backend on Windows (could this explain the performance difference?). My GPU is a GTX 1650 and I have CUDA 11.6 installed. The JAX version I'm using is 0.2.26.

soraros commented 2 years ago

You are currently also measuring JIT compile time in your code. Maybe try something like this:

solve_jax_aot = jit(solve_jax).lower(A, b).compile()
...
for i in range(...):
  x, info = solve_jax_aot(A, b)
...

It's also helpful to set maxiter in jax.scipy.sparse.linalg.gmres.
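
For example (a sketch reusing A and b from the script above; the restart and maxiter values here are just placeholders):

x, info = jax.scipy.sparse.linalg.gmres(
    lambda v: A @ v, b, atol=1e-5, restart=20, maxiter=100,
    solve_method='batched')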

jakevdp commented 2 years ago

Also don't forget block_until_ready(); see https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code for more information.

For what it's worth, this is the result of your benchmark, slightly modified, and run on a Colab CPU runtime:

import jax
import jax.scipy.sparse.linalg
import numpy as np
import scipy.sparse.linalg
from jax import jit

def solve_jax(A, b):
    return jax.scipy.sparse.linalg.gmres(lambda v: A @ v, b, solve_method='batched', atol=1e-5)

solve_jax_jit = jit(solve_jax)

A = np.random.rand(30, 30)
b = np.random.rand(30)
_ = solve_jax_jit(A, b)[0].block_until_ready()

%timeit scipy.sparse.linalg.gmres(A, b, atol=1e-5, restart=20)
# 10 loops, best of 5: 127 ms per loop

%timeit solve_jax_jit(A, b)[0].block_until_ready()
# 100 loops, best of 5: 11.3 ms per loop

Azercoco commented 2 years ago

Compilation time does not seem to be the cause of the issue... I just ran your code and got:

121 ms ± 6.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.15 s ± 407 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

but at least the problem seems to come from my installation.

jakevdp commented 2 years ago

Can you say more about how you're running this? What versions of jax and jaxlib do you have installed? Are you on CPU or GPU? With or without X64 enabled?

Azercoco commented 2 years ago

I'm using the Windows build installed following these instructions (the jaxlib-0.1.75+cuda111-cp39-none-win_amd64.whl wheel from https://github.com/cloudhan/jax-windows-builder), so maybe the issue is within this build.

shoyer commented 2 years ago

This is a known issue with JAX's iterative methods on GPUs. XLA's loops on GPU sync back to the host CPU on each loop iteration, so they are slow if the function being solved is quick to evaluate. I actually made a benchmark for this issue that we passed off to the XLA GPU team.

The workaround is either to run this sort of computation on the CPU instead (e.g., by copying all arrays explicitly to the CPU with jax.device_put), or to try TPUs, for which XLA is able to run loops entirely on device.
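
A minimal sketch of the CPU workaround, reusing the A, b, and solve_jax names from the benchmark above:

import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]                  # the host CPU device
A_cpu = jax.device_put(jnp.asarray(A), cpu)  # commit the operands to the CPU
b_cpu = jax.device_put(jnp.asarray(b), cpu)

# jit compiles for the device its committed inputs live on, so this solve runs on the CPU
x, info = jax.jit(solve_jax)(A_cpu, b_cpu)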

We could also consider writing/leveraging custom CUDA kernels or iterative solvers on GPUs, but that would be a major undertaking and likely would have some disadvantages (like losing pytree support).

michelkluger commented 22 hours ago

Having the same problem on CPU:

import numpy as np
import scipy.sparse as scipy_sparse
from scipy.sparse.linalg import lgmres as scipy_gmres
import time

def create_sparse_matrix(n, density=0.01):
    A = scipy_sparse.random(n, n, density=density, format='csr')
    return A

def benchmark_scipy_gmres(A_scipy, b, tol=1e-5, maxiter=1000):
    start_time = time.time()
    x, info = scipy_gmres(A_scipy, b)
    end_time = time.time()

    # Check convergence
    residual = np.linalg.norm(A_scipy @ x - b)
    converged = info == 0  # GMRES returns 0 if converged

    return x, end_time - start_time, converged, float(residual), int(info)

def run_benchmark(n_values, num_runs=1):
    results = []

    for n in n_values:
        scipy_times = []
        convergence_rates = []
        residuals = []
        infos = []

        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)

            _, scipy_time, converged, residual, info = benchmark_scipy_gmres(A, b)

            scipy_times.append(scipy_time)
            convergence_rates.append(float(converged))
            residuals.append(residual)
            infos.append(info)

        results.append({
            'n': n,
            'scipy_avg_time': np.mean(scipy_times),
            'convergence_rate': np.mean(convergence_rates),
            'avg_residual': np.mean(residuals),
            'avg_info': np.mean(infos)
        })
        print(f"Completed benchmarks for n={n}")

    return results

if __name__ == "__main__":
    n_values = [100, 500, 1000, 2000]
    results = run_benchmark(n_values)

    print("\nBenchmark Results:")
    print("n\tSciPy GMRES (s)\tConv. Rate\tAvg. Residual\tAvg. Info")
    for result in results:
        print(f"{result['n']}\t{result['scipy_avg_time']:.4f}\t\t{result['convergence_rate']:.2f}\t\t{result['avg_residual']:.2e}\t\t{result['avg_info']:.2f}")

vs

import numpy as np
import scipy.sparse as scipy_sparse
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import gmres as jax_gmres
import time

def create_sparse_matrix(n, density=0.01):
    A = scipy_sparse.random(n, n, density=density, format='csr')
    return A

def scipy_to_jax_sparse(scipy_sparse_mat):
    return sparse.BCSR.from_scipy_sparse(scipy_sparse_mat)

def benchmark_jax_gmres(A_scipy, b, tol=1e-5, maxiter=1000):
    A_jax = scipy_to_jax_sparse(A_scipy)
    b_jax = jnp.array(b)

    @jax.jit
    def jax_gmres_solve(A, b):
        x, info = jax_gmres(A, b)
        return x, info

    # Compile the function
    _ = jax_gmres_solve(A_jax, b_jax)

    # Benchmark the compiled function
    start_time = time.time()
    x, info = jax_gmres_solve(A_jax, b_jax)
    x = x.block_until_ready()  # Ensure computation is complete
    end_time = time.time()

    # Check convergence
    residual = jnp.linalg.norm(A_jax @ x - b_jax)
    converged = info == 0  # GMRES returns 0 if converged

    return x, end_time - start_time, converged, float(residual), int(info)

def run_benchmark(n_values, num_runs=1):
    results = []

    for n in n_values:
        jax_times = []
        convergence_rates = []
        residuals = []
        infos = []

        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)

            _, jax_time, converged, residual, info = benchmark_jax_gmres(A, b)

            jax_times.append(jax_time)
            convergence_rates.append(float(converged))
            residuals.append(residual)
            infos.append(info)

        results.append({
            'n': n,
            'jax_avg_time': np.mean(jax_times),
            'convergence_rate': np.mean(convergence_rates),
            'avg_residual': np.mean(residuals),
            'avg_info': np.mean(infos)
        })
        print(f"Completed benchmarks for n={n}")

    return results

if __name__ == "__main__":
    n_values = [100, 500, 1000, 2000]
    results = run_benchmark(n_values)

    print("\nBenchmark Results:")
    print("n\tJAX GMRES (s)\tConv. Rate\tAvg. Residual\tAvg. Info")
    for result in results:
        print(f"{result['n']}\t{result['jax_avg_time']:.4f}\t\t{result['convergence_rate']:.2f}\t\t{result['avg_residual']:.2e}\t\t{result['avg_info']:.2f}")

dfm commented 22 hours ago

@michelkluger — It's hard to say much here since you didn't include the results of your benchmarks, but some high-level thoughts: all of the sparse support in JAX is experimental, so your mileage may vary with the actual performance of these APIs. That said, there have been some changes in how XLA handles CPU iterations, so I'd be interested to know how the performance compares with and without the XLA_FLAGS=--xla_cpu_use_thunk_runtime=false environment variable set, to see if there is a specific regression here.

michelkluger commented 21 hours ago

I created a new script to make the comparison fairer: same matrix and everything.

[screenshot of benchmark results]

import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import gmres as scipy_gmres
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jax_sparse
from jax.scipy.sparse.linalg import gmres as jax_gmres
import time
import pandas as pd

# Constants
MATRIX_DENSITY = 0.1
TOLERANCE = 1e-5
MAX_ITERATIONS = 1000
NUM_RUNS = 1
N_VALUES = [100, 200, 500, 1000]
OUTPUT_FILE = "gmres_benchmark_results.csv"

def create_sparse_matrix(n, density=MATRIX_DENSITY):
    A = sp.random(n, n, density=density, format='csr')
    return A

def benchmark_scipy_gmres(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    start_time = time.time()
    try:
        x, info = scipy_gmres(A, b, atol=tol, maxiter=maxiter, restart=min(maxiter, A.shape[0]))
        end_time = time.time()
        residual = np.linalg.norm(A @ x - b)
        return end_time - start_time, info == 0, residual
    except Exception as e:
        print(f"SciPy GMRES error: {str(e)}")
        return np.nan, False, np.nan

@jax.jit
def jax_gmres_solve(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    x, info = jax_gmres(A, b, tol=tol, maxiter=maxiter, restart=min(maxiter, A.shape[1]))
    return x, info

def benchmark_jax_gmres(A, b, tol=TOLERANCE, maxiter=MAX_ITERATIONS):
    A_jax = jax_sparse.BCOO.from_scipy_sparse(A)
    b_jax = jnp.array(b)

    # Warm-up run
    _ = jax_gmres_solve(A_jax, b_jax)

    start_time = time.time()
    try:
        x, info = jax_gmres_solve(A_jax, b_jax)
        x = x.block_until_ready()
        end_time = time.time()

        residual = jnp.linalg.norm(A_jax @ x - b_jax)
        return end_time - start_time, info == 0, float(residual)
    except Exception as e:
        print(f"JAX GMRES error: {str(e)}")
        return np.nan, False, np.nan

def run_benchmark(n_values, num_runs=NUM_RUNS):
    results = []

    for n in n_values:
        print(f"Running benchmark for n={n}")
        for _ in range(num_runs):
            A = create_sparse_matrix(n)
            b = np.random.rand(n)

            scipy_time, scipy_conv, scipy_res = benchmark_scipy_gmres(A, b)
            jax_time, jax_conv, jax_res = benchmark_jax_gmres(A, b)

            results.append({
                'n': n,
                'scipy_time': scipy_time,
                'scipy_conv': scipy_conv,
                'scipy_res': scipy_res,
                'jax_time': jax_time,
                'jax_conv': jax_conv,
                'jax_res': jax_res,
                'speedup': scipy_time / jax_time if jax_time > 0 else np.nan
            })

    return pd.DataFrame(results)

if __name__ == "__main__":
    results_df = run_benchmark(N_VALUES)

    summary = results_df.groupby('n').mean()
    summary = summary[['scipy_time', 'jax_time', 'speedup', 'scipy_conv', 'jax_conv', 'scipy_res', 'jax_res']]
    summary.columns = ['SciPy Time (s)', 'JAX Time (s)', 'Speedup', 'SciPy Conv. Rate', 'JAX Conv. Rate', 'SciPy Residual', 'JAX Residual']

    print("\nBenchmark Results Summary:")
    pd.set_option('display.float_format', lambda x: f"{x:.4f}")
    print(summary)

    summary.to_csv(OUTPUT_FILE)
    print(f"\nResults have been saved to '{OUTPUT_FILE}'")

michelkluger commented 21 hours ago

> @michelkluger — It's hard to say much here since you didn't include the results of your benchmarks, but some high-level thoughts: all of the sparse support in JAX is experimental, so your mileage may vary with the actual performance of these APIs. That said, there have been some changes in how XLA handles CPU iterations, so I'd be interested to know how the performance compares with and without the XLA_FLAGS=--xla_cpu_use_thunk_runtime=false environment variable set, to see if there is a specific regression here.

I added a new script for a fairer comparison. I am surprised, because from everything I've read JAX was meant to be much faster, but I am very new to this, so it is quite possible I made some mistakes; I'd be happy to learn from them.

dfm commented 21 hours ago

Thanks! I ran this script locally and I can confirm that the performance with XLA_FLAGS=--xla_cpu_use_thunk_runtime=false is comparable to SciPy, but without it we see a major performance regression.

Pinging @ezhulenev who has been looking into related performance issues.

dfm commented 21 hours ago

@michelkluger — For your specific use case, I'd recommend setting that environment variable for now.
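
For reference, one way to do that from Python (a sketch; I'm assuming the flag needs to be in the environment before JAX runs its first computation, and setting it in the shell before launching Python also works):

import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"  # before any JAX work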

michelkluger commented 21 hours ago

> @michelkluger — For your specific use case, I'd recommend setting that environment variable for now.

I ran with the flag:

[screenshot of benchmark results with the flag set]

I set the XLA flag with:

import os
os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

and tested that it was picked up with:

# get XLA_FLAGS env variable
def get_xla_flags():
    try:
        return os.environ["XLA_FLAGS"]
    except KeyError:
        return None
get_xla_flags()

'--xla_cpu_use_thunk_runtime=false'

jakevdp commented 20 hours ago

What happens if you replace the sparse.BCSR matrix with a dense matrix? The jax.experimental.sparse code is just a reference implementation, and is quite slow even for simple operations like sparse matmul. I wouldn't be surprised if slow sparse matmul is driving the performance here.
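
For example, a rough sketch of the dense variant (the n, density, and solver parameters here are arbitrary placeholders):

import numpy as np
import scipy.sparse as sp
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres as jax_gmres

A = sp.random(1000, 1000, density=0.1, format='csr')
b = np.random.rand(1000)

A_dense = jnp.asarray(A.toarray())  # densify to avoid the experimental sparse matmul
b_jax = jnp.asarray(b)

x, info = jax_gmres(A_dense, b_jax, tol=1e-5, maxiter=1000)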