google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax spsolve works fine on CPU but raises an error on GPU! #22500

Open mehranmirramezani opened 1 month ago

mehranmirramezani commented 1 month ago

Description

I am developing an FEM solver in JAX that uses sparse matrices. The code works perfectly fine on CPU, but the exact same code fails on GPU with this error:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: Singular matrix in linear solve.

Attached is a simplified version of the code:

import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
from jax.experimental.sparse import BCSR
from jax.experimental.sparse.linalg import spsolve
jax.config.update('jax_enable_x64', True)

def shape_function(loc_crd):
    shape_fn = jnp.array([
        (1 - loc_crd[0] - loc_crd[1]),
        (loc_crd[0]),
        (loc_crd[1]),
    ])
    return shape_fn

def grad_shape_fn(loc_crd):
    return jax.jacfwd(shape_function)(loc_crd)

def loc_to_glob_map(loc_crd, glob_crd):
    return shape_function(loc_crd) @ glob_crd

def grad_loc_glob_map(loc_crd, glob_crd):
    return jax.jacfwd(loc_to_glob_map)(loc_crd, glob_crd)

def grad_shape_wrt_glob(loc_crd, glob_crd):
    return (jnp.linalg.inv(grad_loc_glob_map(loc_crd, glob_crd).T) @ grad_shape_fn(loc_crd).T).T

def grad_disp(disp, loc_crd, glob_crd):
    return (grad_shape_wrt_glob(loc_crd, glob_crd).T @ disp.reshape(3,2)).T

def weak_form(disp, loc_crd, glob_crd, E, nu):
    lmbda = E * nu /((1+nu) * (1-2*nu))
    mu = E /(2 * (1+nu))
    grad_u = grad_disp(disp, loc_crd, glob_crd)
    strain = 0.5 * (grad_u + grad_u.T)
    sigma = lmbda * jnp.trace(strain) * jnp.eye(2) + 2 * mu * strain
    return thickness * jnp.linalg.det(grad_loc_glob_map(loc_crd, glob_crd)) * jnp.sum(strain * sigma)

def stiff_ele(disp, loc_crd, glob_crd, E, nu):
    quad_weight = 0.5
    return quad_weight * jax.hessian(weak_form)(disp, loc_crd, glob_crd, E, nu)

def assemble_rcv(rcv, i):
    row_id, col_id, value_id = rcv
    row_indices = jnp.repeat(connec[i, :], 6)
    col_indices = jnp.tile(connec[i, :], 6)
    values = all_Ke[i].flatten()

    row_id = jax.lax.dynamic_update_slice(row_id, row_indices, (36 * i,))
    col_id = jax.lax.dynamic_update_slice(col_id, col_indices, (36 * i,))
    value_id = jax.lax.dynamic_update_slice(value_id, values, (36 * i,))

    return (row_id, col_id, value_id), None

def dirichlet_bcs_sparse(rcv_all, hemo_ind, nonhemo_ind, N_nodes):
    row_id, col_id, value_id = rcv_all
    nho_ind, nho_val = nonhemo_ind

    value_id_e = jnp.append(value_id, 0.)
    row_id_e = jnp.append(row_id, -1)

    @jax.jit
    def apply_rhs(rhs_vec_car, val_ind):
        rhs_vec, _ = rhs_vec_car
        val, ind = val_ind
        ids = jnp.where(col_id == ind, size=jnp.shape(col_id)[0], fill_value=-1)[0]
        rhs_vec = rhs_vec.at[row_id_e[ids]].add(-val * value_id_e[ids])
        return (rhs_vec, None), None

    (rhs_vec,_), _ = jax.lax.scan(apply_rhs, (jnp.zeros(2*N_nodes + 1,),None), (nho_val,nho_ind))
    all_rhs = rhs_vec[:-1]

    diri_bcs = jnp.append(hemo_ind, nho_ind)
    red_id = jnp.arange(2*N_nodes)[jnp.isin(jnp.arange(2*N_nodes), diri_bcs, invert=True)]

    def shifted_ids(original_arr, removed_arr, corr_arr, val_arr):
        mask = jnp.isin(original_arr, removed_arr, invert=True)
        filtered_array = original_arr[mask]
        unique_numbers, shifted_array = jnp.unique(filtered_array, return_inverse=True)

        return shifted_array, corr_arr[mask], val_arr[mask]

    row_id_n, col_id_n, val_id_n = shifted_ids(row_id, diri_bcs, col_id, value_id)
    col_id_red, row_id_red, val_id_red = shifted_ids(col_id_n, diri_bcs, row_id_n, val_id_n)

    return all_rhs[red_id], (row_id_red, col_id_red, val_id_red)

connec = jnp.array([[1, 2, 3, 4, 5, 6],[3, 4, 7, 8, 5, 6]]) - 1
ele_glob_crd = jnp.array([[[0., 0.], [1., 0.], [0., 1.]],[[1., 0.], [1., 1.], [0., 1.]]])
hem_ind = jnp.array([0, 1, 2], dtype=jnp.int64)
nhem_ind = (jnp.array([6]),jnp.array([0.1]))

num_hom = hem_ind.shape[0]
num_nhom = nhem_ind[0].shape[0]

N_nodes = 4
N_ele = 2

thickness, E, nu = 1., 0.05, 0.3
quad_points = jnp.array([1./3., 1./3.])
disp_nodal = jnp.zeros(6,)

vmap_Ke = jax.vmap(stiff_ele, in_axes=(None, None, 0, None, None))
all_Ke = vmap_Ke(disp_nodal, quad_points, ele_glob_crd, E, nu)

# compute sparse global stiffness matrix
row_id = jnp.zeros(36 * N_ele, dtype=jnp.int64)
col_id = jnp.zeros(36 * N_ele, dtype=jnp.int64)
value_id = jnp.zeros(36 * N_ele)
rcv_all, _ = jax.lax.scan(assemble_rcv, (row_id, col_id, value_id), jnp.arange(N_ele))
row_id, col_id, value_id = rcv_all

indices = jnp.vstack((row_id, col_id)).T
bcoo = BCOO((value_id, indices), shape=(2*N_nodes, 2*N_nodes)).sum_duplicates()
rcv_all = (bcoo.indices[:,0], bcoo.indices[:,1], bcoo.data)

# impose BCs
all_rhs, rcv_all_bc = dirichlet_bcs_sparse(rcv_all, hem_ind, nhem_ind, N_nodes)

row_id_bc, col_id_bc, value_id_bc = rcv_all_bc
indices_bc = jnp.vstack((row_id_bc, col_id_bc)).T
K_bcoo = BCOO((value_id_bc, indices_bc), shape=(2*N_nodes - (num_hom+num_nhom), 2*N_nodes - (num_hom+num_nhom))).sum_duplicates()

K_bcsr = BCSR.from_bcoo(K_bcoo)
sol_red = spsolve(K_bcsr.data,K_bcsr.indices,K_bcsr.indptr,all_rhs,tol=1e-06,reorder=1)

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.23.dev20240502
numpy: 1.26.4
python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]

$ nvidia-smi
Wed Jul 17 17:43:15 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2080 Ti     On  |   00000000:B1:00.0 Off |                  N/A |
|  0%   33C    P0             64W /  260W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   2131480      C   python                                        156MiB |
+-----------------------------------------------------------------------------------------+

jakevdp commented 1 month ago

I suspect your matrix is very close to singular.

On CPU, the solver is implemented using float64 intermediates (via scipy's spsolve), while on GPU the solver is implemented using float32 intermediates (via CUDA's csrlsvqr).

If that's the case, there's not much we can do here, unfortunately, aside from recommending that you do not pass near-singular matrices to spsolve.
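One way to test the near-singularity hypothesis on a small problem is to look at the condition number. The sketch below is an illustration, not part of JAX or of the original thread: the helper `precision_check` and the 2x2 test matrix are hypothetical, and it relies on the common rule of thumb that a stable solver's relative error is roughly cond(K) times the machine epsilon of the working precision. It densifies the matrix, so it is only meant for small test cases.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO

jax.config.update('jax_enable_x64', True)


def precision_check(K):
    """Hypothetical helper: rough estimate of precision lost when solving with K.

    Densifies K, so only use it on small matrices. Rule of thumb (an
    estimate, not a bound): relative solve error ~ cond(K) * machine eps.
    """
    cond = float(jnp.linalg.cond(K.todense()))
    return {
        'cond': cond,
        'err_est_float32': cond * float(np.finfo(np.float32).eps),
        'err_est_float64': cond * float(np.finfo(np.float64).eps),
    }


# Hypothetical nearly singular 2x2 matrix (not taken from the issue).
K_near = BCOO.fromdense(jnp.array([[1.0, 1.0],
                                   [1.0, 1.0 + 1e-9]]))
report = precision_check(K_near)
print(report)
```

For the code in the issue, the analogous call would be `precision_check(K_bcoo)`. An `err_est_float32` near or above 1 suggests that a float32 solve cannot be trusted, even if a float64 solve of the same matrix still looks fine.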

mehranmirramezani commented 1 month ago

Many thanks for the details on the solver. I am wondering if there is a way to switch to CPU just for the solver part. I tried the code below, but it still seems to run on the GPU and gives the same error.

jax.device_put(spsolve(K_bcsr.data, K_bcsr.indices, K_bcsr.indptr, all_rhs), device=jax.devices('cpu')[0])

jakevdp commented 1 month ago

Device-putting the output of spsolve on the CPU will not cause the operation to be executed on the CPU. But if you device-put the inputs to spsolve onto the CPU, then the operation should execute on the CPU.
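A minimal sketch of that suggestion, assuming an eager (non-jitted) call: the inputs, not the output, are committed to the CPU device with `jax.device_put`, so `spsolve` runs there. The 3x3 matrix and right-hand side are hypothetical stand-ins for `K_bcsr` and `all_rhs` from the issue.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCSR
from jax.experimental.sparse.linalg import spsolve

jax.config.update('jax_enable_x64', True)

# Hypothetical small system standing in for K_bcsr / all_rhs from the issue.
K_dense = jnp.array([[4.0, 1.0, 0.0],
                     [1.0, 3.0, 1.0],
                     [0.0, 1.0, 2.0]])
rhs = jnp.array([1.0, 2.0, 3.0])
K_bcsr = BCSR.fromdense(K_dense)

cpu = jax.devices('cpu')[0]

# Move the *inputs* to the CPU; device_put accepts a whole tuple (PyTree) at once.
data, indices, indptr, b = jax.device_put(
    (K_bcsr.data, K_bcsr.indices, K_bcsr.indptr, rhs), cpu)

# With CPU-resident inputs, the solve executes on the CPU backend.
x = spsolve(data, indices, indptr, b, tol=1e-6, reorder=1)
print(x, x.devices())
```

The result `x` stays on the CPU; a further `jax.device_put(x, jax.devices('gpu')[0])` would be needed if the rest of the pipeline expects it on the GPU.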

mrkwjc commented 1 month ago

Hi! I can confirm that putting the arguments onto the CPU triggers the solve on the CPU. That's nice! Solving on CPU also seems to be much faster in many cases.

I have an additional question, however. Is there a way to reuse the factorized matrix for many right-hand sides? With scipy one can do:

import scipy.sparse.linalg

factorized_K = scipy.sparse.linalg.splu(K)  # K: a scipy.sparse matrix, ideally in CSC format
x = factorized_K.solve(b)

and then solve for many b vectors. Are there any obstacles to implementing something like that in JAX?
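JAX has no sparse counterpart to `splu`. For systems small enough to store densely, though, `jax.scipy.linalg.lu_factor` and `lu_solve` already give the same factor-once, solve-many pattern. A minimal sketch with a hypothetical 3x3 matrix (dense storage is the key assumption here, so this does not scale to large FEM meshes):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# Hypothetical small system; dense storage is only reasonable for small n.
K = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])

# Factorize once: lu_and_piv is a (lu, piv) tuple, i.e. an ordinary PyTree.
lu_and_piv = lu_factor(K)


@jax.jit
def solve(lu_and_piv, b):
    # Reuses the stored factors; only the triangular solves run here.
    return lu_solve(lu_and_piv, b)


# Many right-hand sides: call repeatedly, or stack them as columns of B.
B = jnp.array([[1.0, 0.0],
               [2.0, 1.0],
               [3.0, 0.0]])
X = solve(lu_and_piv, B)
x0 = solve(lu_and_piv, B[:, 0])
```

Because the factors are plain arrays in a tuple, they can be passed into and out of jitted functions like any other argument.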

jakevdp commented 1 month ago

There's not currently any method of pre-factorization available in JAX. JAX's spsolve was implemented to wrap a CUDA routine that a user needed; the CPU implementation is basically a reference implementation based on pure_callback. If you're on CPU and need further functionality, the best approach would probably be to use pure_callback directly to call that functionality in scipy.
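A minimal sketch of that approach, assuming the matrix is concrete (known before tracing, e.g. a scipy sparse matrix): factorize once with `scipy.sparse.linalg.splu`, keep the factorization object in a Python closure, and expose a JAX-traceable `solve(b)` through `jax.pure_callback`. The helper name `make_prefactorized_solver` and the test matrix are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse
import scipy.sparse.linalg


def make_prefactorized_solver(K):
    """Hypothetical helper: factorize a concrete sparse K once with splu and
    return a JAX-traceable solve(b) that reuses the factors via pure_callback.
    """
    lu = scipy.sparse.linalg.splu(scipy.sparse.csc_matrix(K, dtype=np.float64))

    def _host_solve(b):
        # Runs on the host with NumPy inputs; reuses the stored LU factors.
        x = lu.solve(np.asarray(b, dtype=np.float64))
        return x.astype(b.dtype)

    def solve(b):
        result_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
        return jax.pure_callback(_host_solve, result_shape, b)

    return solve


# Hypothetical small test system (not from the issue).
K = scipy.sparse.csr_matrix(np.array([[4.0, 1.0, 0.0],
                                      [1.0, 3.0, 1.0],
                                      [0.0, 1.0, 2.0]]))
solve = jax.jit(make_prefactorized_solver(K))
rhs_list = [jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0, 0.0])]
solutions = [solve(b) for b in rhs_list]
```

The closure sidesteps the question of where to keep the factorization, but it only works when K is fixed ahead of time; if K itself is a traced value, the factorization cannot be done this way.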

mrkwjc commented 1 month ago

Thanks, pure_callback is the way. However, what I do not see immediately is how to preserve an object so that it can be reused between subsequent callbacks. In other words, how to perform a sequence like this:

The factorized K must be 'saved' somehow, maybe as a global variable?

jakevdp commented 1 month ago

factorized_K is just an object with a few arrays and static attributes. You could probably register it as a PyTree, return it to JAX, and then pass it to the next pure_callback.
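A sketch of that idea, under several assumptions: the factorization is done eagerly on the host (the sizes of L and U are only known afterwards, which makes it awkward to produce from a `pure_callback` that needs its result shapes up front); the class `SparseLUFactors` and the helpers `factorize` and `_host_solve` are hypothetical; and the solve is rebuilt from scipy's documented convention `A = Pr.T @ L @ U @ Pc.T` using `spsolve_triangular`, rather than by reconstructing a `SuperLU` object, which scipy does not offer.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from jax.tree_util import register_pytree_node_class

jax.config.update('jax_enable_x64', True)


@register_pytree_node_class
class SparseLUFactors:
    """Hypothetical PyTree holding the arrays of a scipy SuperLU factorization.

    Leaves: (L_data, L_indices, L_indptr, U_data, U_indices, U_indptr,
    perm_r, perm_c). The matrix size n is static aux data.
    """

    def __init__(self, arrays, n):
        self.arrays = tuple(arrays)
        self.n = n

    def tree_flatten(self):
        return self.arrays, self.n

    @classmethod
    def tree_unflatten(cls, n, arrays):
        return cls(arrays, n)


def factorize(K):
    """Factorize a concrete matrix eagerly on the host (outside jit)."""
    lu = scipy.sparse.linalg.splu(scipy.sparse.csc_matrix(K, dtype=np.float64))
    L, U = lu.L.tocsr(), lu.U.tocsr()
    L.sort_indices()
    U.sort_indices()
    arrays = (L.data, L.indices, L.indptr, U.data, U.indices, U.indptr,
              lu.perm_r, lu.perm_c)
    return SparseLUFactors([jnp.asarray(a) for a in arrays], K.shape[0])


def _host_solve(n, L_data, L_indices, L_indptr, U_data, U_indices, U_indptr,
                perm_r, perm_c, b):
    # Rebuild the factors (copies, since callback inputs may be read-only).
    L = scipy.sparse.csr_matrix((L_data, L_indices, L_indptr), shape=(n, n), copy=True)
    U = scipy.sparse.csr_matrix((U_data, U_indices, U_indptr), shape=(n, n), copy=True)
    pb = np.empty_like(b, dtype=np.float64)
    pb[perm_r] = b                                   # pb = Pr @ b
    y = scipy.sparse.linalg.spsolve_triangular(L, pb, lower=True)
    z = scipy.sparse.linalg.spsolve_triangular(U, y, lower=False)
    return z[perm_c].astype(b.dtype)                 # x = Pc @ z


@jax.jit
def solve(factors, b):
    result_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
    host_fn = functools.partial(_host_solve, factors.n)  # n is static aux data
    return jax.pure_callback(host_fn, result_shape, *factors.arrays, b)


# Hypothetical nonsymmetric test matrix, so both permutations are exercised.
K = np.array([[4.0, 1.0, 0.0, 0.0],
              [1.0, 3.0, 1.0, 0.0],
              [0.0, 2.0, 5.0, 1.0],
              [1.0, 0.0, 1.0, 6.0]])
factors = factorize(K)
rhs_list = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([0.0, 1.0, 0.0, -1.0])]
solutions = [solve(factors, jnp.asarray(b)) for b in rhs_list]
```

Compared with closing over the `SuperLU` object, this keeps the factors as ordinary JAX arrays that can be returned from and passed between functions, at the cost of rebuilding the CSR matrices and using generic triangular solves on every call.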