jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

cuSolverMG support for distributed arrays #16597

Open allegro0132 opened 1 year ago

allegro0132 commented 1 year ago

JAX already has excellent parallelism support via jax.Array: the compiler can choose shardings that maximize parallelism for elementwise operations and matrix multiplication. However, linear algebra operations (such as LU factorization and eigenvalue problems) currently run on only a single GPU.
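A minimal sketch of the situation described above (assumes a working JAX install; it runs even on a single CPU device). The jitted elementwise/matmul computation is one the compiler is free to shard across devices, while `jax.lax.linalg.lu` lowers to a single-device factorization kernel:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.asarray(np.random.default_rng(0).normal(size=(8, 8)),
                dtype=jnp.float32)

@jax.jit
def f(a):
    # Elementwise op followed by a matmul: under a multi-device mesh,
    # the compiler may pick shardings for these automatically.
    return (a * 2.0) @ a

y = f(x)

# LU factorization, by contrast, currently runs on one device:
lu, pivots, permutation = jax.lax.linalg.lu(x)
```

Under a multi-device `jax.sharding.Mesh`, `f` can be partitioned, but the `lu` call above gathers its operand onto a single device first.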

NVIDIA already provides a GPU-accelerated ScaLAPACK-style library, cuSolverMg, which offers distributed linear algebra solvers for single-node multi-GPU setups. For example, jax.lax.linalg.lu could call cusolverMgGetrf to perform a multi-GPU LU factorization.
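One hedged sketch of where such a solver could plug in, without touching JAX internals: `jax.pure_callback` lets an external kernel be invoked from traced code. Here `scipy.linalg.lu_factor` on the host stands in for a hypothetical cusolverMgGetrf-backed routine; the callback boundary is where a real cuSolverMg binding would slot in:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

def _host_lu(a):
    # Placeholder for a multi-GPU kernel; a real integration would
    # call into a cusolverMgGetrf wrapper here instead of SciPy.
    lu, piv = scipy.linalg.lu_factor(np.asarray(a))
    return lu.astype(a.dtype), piv.astype(np.int32)

def lu_via_callback(a):
    n = a.shape[0]
    out_shapes = (
        jax.ShapeDtypeStruct((n, n), a.dtype),  # packed LU factors
        jax.ShapeDtypeStruct((n,), jnp.int32),  # pivot indices
    )
    return jax.pure_callback(_host_lu, out_shapes, a)
```

For example, `jax.jit(lu_via_callback)(jnp.eye(4, dtype=jnp.float32))` returns the packed factors and pivots; a proper lowering (rather than a callback) would be needed to keep the operand sharded across GPUs.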

The matrix size is strictly limited by a single GPU's memory, so adding support for multi-GPU linear solvers would let us solve larger problems. I am also willing to help add this feature.

Thanks for the awesome work!

Happy2Git commented 1 year ago

Great suggestion!!!