jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Request for implementation of scipy.linalg.solve_banded #15880

Open tomsturges opened 1 year ago

tomsturges commented 1 year ago

Hello,

It would be great to have an implementation of scipy.linalg.solve_banded in JAX. Failing that, I'm interested in similar methods for solving linear systems of the form A x = b, where A is sparse (in my case tridiagonal) and may be non-Hermitian. Thanks!

Tom

shoyer commented 1 year ago

Which backend are you interested in?

scipy.linalg.solve_banded uses LAPACK's gbsv, which would also be an option for JAX on CPU.
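As a pure-JAX stopgap that runs on any backend, one could accept SciPy's banded storage format, expand it to a dense matrix, and solve with jax.scipy.linalg.lu_factor and lu_solve. In that format ab has shape (l + u + 1, n) and ab[u + i - j, j] == a[i, j]. This is a minimal sketch: banded_to_dense and solve_banded_dense are hypothetical helper names, not JAX APIs, and the dense expansion gives up the linear-in-n cost that a banded LAPACK routine such as gbsv provides.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


def banded_to_dense(l, u, ab):
    """Expand SciPy-style banded storage into a dense (n, n) matrix.

    Layout as in scipy.linalg.solve_banded: ab has shape (l + u + 1, n) and
    ab[u + i - j, j] == a[i, j]. Entries of ab outside the band are ignored.
    """
    ab = jnp.asarray(ab)
    n = ab.shape[1]
    a = jnp.zeros((n, n), dtype=ab.dtype)
    for k in range(l + u + 1):  # l and u are static Python ints
        offset = u - k  # > 0: superdiagonal, 0: main diagonal, < 0: subdiagonal
        if offset >= 0:
            vals = ab[k, offset:]  # first `offset` entries of this row are padding
        else:
            vals = ab[k, : n + offset]  # last `-offset` entries are padding
        a = a + jnp.diag(vals, k=offset)
    return a


def solve_banded_dense(l_and_u, ab, b):
    """Hypothetical stand-in mirroring scipy.linalg.solve_banded's signature.

    Expands to dense and uses LU with partial pivoting, so it costs
    O(n^3) time and O(n^2) memory instead of a banded solver's O(n).
    """
    l, u = l_and_u
    a = banded_to_dense(l, u, ab)
    return lu_solve(lu_factor(a), b)
```

Because l and u are Python integers, the loop unrolls at trace time, so solve_banded_dense can be wrapped in jax.jit with l_and_u marked static.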

See jaxlib/lapack.py for examples: https://github.com/google/jax/blob/a4382d7600a5fb0a8b5f8df3ab8671554b1b8b27/jaxlib/lapack.py
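A lighter-weight CPU route than writing a custom call in the style of jaxlib/lapack.py is to call SciPy itself from inside JAX with jax.pure_callback. This is a sketch under assumptions: solve_banded_callback is a hypothetical name, the callback runs on the host (so data on a GPU is copied back and forth), and JAX cannot differentiate through it.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def solve_banded_callback(l_and_u, ab, b):
    """Hypothetical wrapper: run scipy.linalg.solve_banded on the host from JAX.

    Not differentiable; on GPU the inputs are transferred to host memory.
    """
    ab, b = jnp.asarray(ab), jnp.asarray(b)
    out_dtype = jnp.result_type(ab, b)
    out_spec = jax.ShapeDtypeStruct(b.shape, out_dtype)

    def _host_solve(ab_np, b_np):
        x = scipy.linalg.solve_banded(l_and_u, np.asarray(ab_np), np.asarray(b_np))
        # pure_callback requires the result to match out_spec's dtype exactly.
        return np.asarray(x, dtype=out_dtype)

    return jax.pure_callback(_host_solve, out_spec, ab, b)
```

Under jax.jit, l_and_u should be passed as a static argument (for example with static_argnums=0), since it is a tuple of Python integers rather than an array.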

patrick-kidger commented 1 year ago

As another option, Lineax (https://github.com/google/lineax) now exists and offers a tridiagonal (Thomas) solver.

It also offers a TridiagonalLinearOperator, which represents such operators efficiently (as just three vectors).
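To make the three-vector idea concrete, here is a minimal sketch of the Thomas algorithm written with jax.lax.scan. It is not Lineax's implementation; thomas_solve is a hypothetical name, and the convention of padding the sub- and superdiagonals to length n (with lower[0] and upper[-1] ignored) is an assumption chosen for simplicity. Like the textbook Thomas algorithm it does no pivoting, so it is only reliable for well-conditioned cases such as diagonally dominant matrices; complex, non-Hermitian inputs work as long as that holds.

```python
import jax.numpy as jnp
from jax import lax


def thomas_solve(lower, diag, upper, rhs):
    """Solve A x = rhs for tridiagonal A with the Thomas algorithm (no pivoting).

    All inputs have shape (n,): lower[i] = A[i, i-1], diag[i] = A[i, i],
    upper[i] = A[i, i+1]. lower[0] and upper[-1] are padding and are ignored.
    """
    lower, diag, upper, rhs = map(jnp.asarray, (lower, diag, upper, rhs))
    # Use one inexact dtype everywhere so the scan carry keeps a fixed type.
    dtype = jnp.promote_types(jnp.result_type(lower, diag, upper, rhs), jnp.float32)
    lower, diag, upper, rhs = (v.astype(dtype) for v in (lower, diag, upper, rhs))
    lower = lower.at[0].set(0)
    upper = upper.at[-1].set(0)
    zero = jnp.zeros((), dtype=dtype)

    # Forward sweep: eliminate the subdiagonal, producing modified c' and d'.
    def forward(carry, row):
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        c_new = c / denom
        d_new = (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    _, (c_prime, d_prime) = lax.scan(forward, (zero, zero), (lower, diag, upper, rhs))

    # Back substitution from the last row to the first; with reverse=True the
    # stacked outputs are still returned in the original row order.
    def backward(x_next, row):
        c, d = row
        x = d - c * x_next
        return x, x

    _, x = lax.scan(backward, zero, (c_prime, d_prime), reverse=True)
    return x
```

The scan performs n sequential steps, so a single system gives a GPU little parallel work. When many independent systems need solving, batching them with jax.vmap(thomas_solve) is one way to expose more parallelism.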

hrkz commented 1 year ago

I'm also using JAX to solve PDEs that produce general banded or tridiagonal matrices (edit: also complex-valued), but mostly on GPUs.

I tried a tridiagonal solver similar to the one in Lineax, but performance was disappointing: roughly the same as a dense lu_solve.

Are there any CUDA banded solvers that we could link to directly?
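For the tridiagonal case specifically, JAX itself ships jax.lax.linalg.tridiagonal_solve, which takes the three diagonals as length-n vectors (with dl[0] and du[-1] required to be zero) and a two-dimensional right-hand side. As far as I know its GPU path calls cuSPARSE, but backend coverage (CPU/TPU), complex dtype support, and differentiability have varied between JAX versions, so treat this as a sketch and check the jax.lax.linalg documentation for the version you have installed.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

# Diagonally dominant 4x4 tridiagonal system stored as three length-n vectors.
dl = jnp.array([0.0, 2.0, 2.0, 2.0])  # subdiagonal: dl[i] = A[i, i-1], dl[0] must be 0
d = jnp.array([4.0, 5.0, 6.0, 7.0])   # main diagonal: d[i] = A[i, i]
du = jnp.array([1.0, 1.0, 1.0, 0.0])  # superdiagonal: du[i] = A[i, i+1], du[-1] must be 0
b = jnp.array([[1.0], [2.0], [3.0], [4.0]])  # right-hand side with shape (n, k)

x = tridiagonal_solve(dl, d, du, b)  # solution with the same shape as b
```

Leading batch dimensions on dl, d, du and b are also accepted, which lets one call solve many independent systems.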

tomsturges commented 1 year ago

Thanks for the responses. I have access to both CPUs and GPUs. @shoyer, I didn't really understand what the Python file you linked does (sorry, my JAX/Python knowledge is fairly limited). @patrick-kidger, thanks, Lineax looks promising; I'll report back if and when I return to the project I was working on.