Open tomsturges opened 1 year ago
Which backend are you interested in?
scipy.linalg.solve_banded uses LAPACK's gbsv, which would also be an option for JAX on CPU.
See jaxlib/lapack.py for examples: https://github.com/google/jax/blob/a4382d7600a5fb0a8b5f8df3ab8671554b1b8b27/jaxlib/lapack.py
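For readers unfamiliar with the banded interface, here is a minimal sketch of the diagonal-ordered storage that scipy.linalg.solve_banded documents for its `ab` argument, where `ab[u + i - j, j] == a[i, j]`. The helper name `to_banded` and the example matrix are made up for illustration, and the code is plain Python so that it runs without SciPy or JAX installed.

```python
def to_banded(a, l, u):
    """Pack a dense square matrix (list of rows) into banded storage.

    Returns l + u + 1 rows of length n with ab[u + i - j][j] == a[i][j]
    for every entry inside the band. Unused corner slots stay 0.
    `to_banded` is a hypothetical helper, not a SciPy or JAX function.
    """
    n = len(a)
    ab = [[0] * n for _ in range(l + u + 1)]
    for i in range(n):
        # Only columns within l below / u above the diagonal belong to the band.
        for j in range(max(0, i - l), min(n, i + u + 1)):
            ab[u + i - j][j] = a[i][j]
    return ab


# A 4x4 non-symmetric tridiagonal matrix (l = u = 1).
a = [
    [2, -1, 0, 0],
    [3, 2, -1, 0],
    [0, 3, 2, -1],
    [0, 0, 3, 2],
]
ab = to_banded(a, l=1, u=1)
# ab[0]: superdiagonal (padded at the front)
# ab[1]: main diagonal
# ab[2]: subdiagonal (padded at the end)
for row in ab:
    print(row)
```

With SciPy available, the resulting array would be passed as `scipy.linalg.solve_banded((1, 1), ab, b)`.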
As another option, Lineax (https://github.com/google/lineax) now exists and offers a tridiagonal (Thomas) solver.
It also offers a TridiagonalLinearOperator which allows for representing such operators efficiently (just three vectors).
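To make the "three vectors" idea concrete, here is a rough sketch of the textbook Thomas algorithm operating directly on the sub-, main and superdiagonal. It is not Lineax's implementation; the names `thomas_solve` and `tridiag_matvec` are invented for this example, it is written in plain Python rather than JAX, and it does no pivoting, so it assumes a well-conditioned (for example diagonally dominant) matrix. Complex and non-Hermitian inputs work unchanged.

```python
def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system A x = rhs without pivoting.

    lower: n - 1 subdiagonal entries, diag: n main-diagonal entries,
    upper: n - 1 superdiagonal entries. Hypothetical helper for illustration.
    """
    n = len(diag)
    c = [0] * n  # modified superdiagonal
    d = [0] * n  # modified right-hand side
    # Forward sweep: eliminate the subdiagonal row by row.
    if n > 1:
        c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i - 1] * c[i - 1]
        if i < n - 1:
            c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i - 1] * d[i - 1]) / denom
    # Back substitution, from the last unknown upwards.
    x = [0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


def tridiag_matvec(lower, diag, upper, x):
    """Multiply the tridiagonal matrix by x (used only to build a test case)."""
    n = len(diag)
    y = []
    for i in range(n):
        s = diag[i] * x[i]
        if i > 0:
            s += lower[i - 1] * x[i - 1]
        if i < n - 1:
            s += upper[i] * x[i + 1]
        y.append(s)
    return y


# Complex, non-Hermitian, diagonally dominant example.
lower = [1, 2]
diag = [4 + 1j, 4, 4 - 1j]
upper = [-1j, 1]
x_true = [1, 2j, -1]
b = tridiag_matvec(lower, diag, upper, x_true)
x = thomas_solve(lower, diag, upper, b)
print(max(abs(xi - ti) for xi, ti in zip(x, x_true)))  # small residual
```

Both loops carry a dependency from one row to the next, so the work is inherently sequential in the matrix size. In JAX they could presumably be written with `jax.lax.scan`, but the serial structure would remain.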
I'm also using JAX in the context of solving PDEs that produce general banded / tridiagonal (edit: also complex) matrices, but mostly using GPUs.
I tried a similar version of the tridiagonal solver in Lineax, but performance was disappointing (on par with a dense lu_solve).
Are there any CUDA banded solvers that we could link to directly?
Thanks for the responses. I have access to both CPU and GPU. @shoyer I didn't really understand what that Python file you linked does (sorry, my JAX/Python knowledge is fairly low). @patrick-kidger thanks, Lineax looks promising, I'll report back if and when I return to the project I was working on before.
Hello,
It would be great if we could get an implementation of scipy.linalg.solve_banded in JAX. Otherwise I am interested in similar methods for solving matrix equations of the form A*x=b, where A is a sparse matrix (in my case tridiagonal) and can also be non-Hermitian.
Thanks!
Tom