JAX already has excellent parallelism support via `jax.Array`: the compiler can choose shardings that maximize parallelism for elementwise operations and matrix multiplication. But linear algebra operations (like LU factorization and eigenvalue problems) only run on a single GPU.
NVIDIA already ships a GPU-accelerated ScaLAPACK-like library, `cuSolverMg`, which provides distributed linear algebra solvers for single-node multi-GPU systems. For example, `jax.lax.linalg.lu` could call `cusolverMgGetrf` to perform a multi-GPU LU factorization.
Today the matrix size is strictly limited by a single GPU's memory, so adding support for multi-GPU linear solvers would let us tackle much larger problems. I am also willing to help implement this feature.
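For illustration, here is a minimal sketch of the current behavior (assuming a host with multiple GPUs; the matrix size is arbitrary): sharded elementwise ops and matmuls run in parallel across devices, while the LU factorization falls back to a per-device kernel, so the full matrix must fit on one GPU.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all visible devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x", None))

# A square matrix sharded row-wise across the mesh.
a = jax.device_put(jnp.eye(8192, dtype=jnp.float32), sharding)

# Elementwise ops and matmuls keep the sharding: the compiler
# parallelizes them across devices.
b = (a + 1.0) @ a.T

# LU factorization has no multi-GPU lowering: it runs as a
# single-device cuSOLVER call, so the whole matrix must fit
# in one GPU's memory.
lu, piv = jax.scipy.linalg.lu_factor(a)
```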
Thanks for the awesome work!