patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Support for custom solvers #44

Open SNMS95 opened 1 year ago

SNMS95 commented 1 year ago

Hi guys,

I am using a JAX-based framework for finite element analysis. Currently, I use solvers from two sources: either jax.scipy.sparse.linalg.bicgstab or PETSc. I like that lineax already has many solvers and improves on what JAX offers, so I will start replacing the JAX solvers with lineax solvers. However, I also want to keep the PETSc option, since PETSc has many more solvers and preconditioners.

Would it be possible to wrap PETSc with lineax, so that users can access its solvers while staying integrated in a JAX workflow (no tracing errors, no issues with higher-order autodiff)?

_Before coming across lineax, I was thinking of using jax.pure_callback + jax.custom_jvp to do this. But this is painful, especially when sparse matrices are involved._
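For comparison, here is a minimal sketch of the hand-rolled route in plain JAX. It swaps `custom_jvp` for `jax.lax.custom_linear_solve`, which accepts a `transpose_solve` and therefore supports reverse mode as well as forward mode. NumPy's dense `np.linalg.solve` stands in for PETSc (an assumption; a real version would call into PETSc, e.g. via petsc4py, inside the host function), and `_host_solve`, `_callback_solve` and `callback_linear_solve` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_solve(matrix, rhs):
    # Hypothetical stand-in for an external (e.g. PETSc) solve; runs on the host.
    matrix = np.asarray(matrix)
    rhs = np.asarray(rhs)
    return np.linalg.solve(matrix, rhs).astype(rhs.dtype)


def _callback_solve(matrix, rhs):
    # pure_callback lets traced/jitted JAX code call the host function.
    out_type = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
    return jax.pure_callback(_host_solve, out_type, matrix, rhs)


def callback_linear_solve(matrix, b):
    # custom_linear_solve differentiates implicitly: tangents w.r.t. `matrix`
    # flow through `matvec`, and reverse mode uses `transpose_solve`.
    # JAX never needs to differentiate `_host_solve` itself.
    return jax.lax.custom_linear_solve(
        lambda x: matrix @ x,
        b,
        solve=lambda _matvec, rhs: _callback_solve(matrix, rhs),
        transpose_solve=lambda _vecmat, rhs: _callback_solve(matrix.T, rhs),
    )
```

With this, `jax.grad` of a loss built on `callback_linear_solve` works with respect to both the matrix and the right-hand side. Depending on the JAX version, `jax.vmap` over it may additionally require choosing a `vmap_method` for `jax.pure_callback`.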

patrick-kidger commented 1 year ago

Yup, that should be possible. Subclass lineax.AbstractLinearSolver. Take a look at the implementation of any existing solver for an example, along with the docstring of each method of AbstractLinearSolver for what needs implementing.

In particular, you should be able to call out to PETSc in the .compute method using jax.pure_callback. Lineax will then handle autodiff etc. automatically.