SNMS95 opened this issue 1 year ago.
Yup, that should be possible. Subclass `lineax.AbstractSolver`. Take a look at the implementation of any existing solver for an example, along with the docstring of each method of `AbstractSolver` for what needs implementing.

In particular, you should be able to call out to PETSc in the `.compute` method using a `jax.pure_callback`. Lineax will handle autodiff etc. automatically.
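A minimal sketch of the callback half of this, assuming a dense matrix and using NumPy's `np.linalg.solve` as a stand-in for the PETSc call (in practice that line would set up and run a petsc4py solver instead). `host_solve` and `external_solve` are hypothetical names. `jax.pure_callback` only needs the shape and dtype of the result, which is what lets the call be traced under `jax.jit`:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_solve(A, b):
    # Runs on the host with plain NumPy arrays; a hypothetical stand-in
    # for building and running a PETSc solver.
    A, b = np.asarray(A), np.asarray(b)
    x = np.linalg.solve(A, b)
    # The result must match the declared shape and dtype exactly.
    return x.astype(b.dtype)


def external_solve(A, b):
    # Tell JAX what the callback returns so the call can be traced.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(host_solve, out_spec, A, b)


solve_jit = jax.jit(external_solve)

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = solve_jit(A, b)  # approximately [1/11, 7/11]
```

On its own, `external_solve` is not differentiable: `pure_callback` has no derivative rule, so `jax.grad` through it raises an error. Per the reply above, putting this call inside a solver's `.compute` is what lets Lineax supply the derivatives. Recent JAX versions also accept a `vmap_method` argument to `jax.pure_callback` to control how the callback behaves under `jax.vmap`.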
Hi guys,

I am using a JAX-based framework to do finite element analysis. Currently, I use solvers from two sources: either `jax.scipy.sparse.linalg.bicgstab` or PETSc. I like the fact that `lineax` has many solvers at the moment and is better than JAX's built-in ones, so I will start substituting the JAX solvers with `lineax` solvers. However, I want to keep the PETSc solver option as well, since PETSc has a lot more solvers and preconditioners. Would it be possible to wrap PETSc with `lineax` so that users can access those solvers while staying integrated in a JAX workflow (no tracing errors or issues with higher-order autodiff)?
_Before coming across lineax, I was thinking of using `pure_callback` + `custom_jvp` to do this. But this is painful, especially when sparse matrices are involved._
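For comparison, a rough sketch of the manual route, assuming a dense `A` and again using `np.linalg.solve` as a hypothetical stand-in for PETSc (`host_solve`, `callback_solve` and `solve` are illustrative names). It uses `jax.custom_vjp` rather than `custom_jvp`, since getting reverse mode from a `custom_jvp` rule would require JAX to transpose the callback in the tangent rule, which it cannot do. The adjoint of x = A⁻¹b is hand-coded: with λ = A⁻ᵀx̄, the cotangents are b̄ = λ and Ā = −λxᵀ.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_solve(A, b):
    # Hypothetical stand-in for an external (e.g. PETSc) solve on the host.
    A, b = np.asarray(A), np.asarray(b)
    return np.linalg.solve(A, b).astype(b.dtype)


def callback_solve(A, b):
    # Opaque to JAX: only the output shape/dtype is declared.
    spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(host_solve, spec, A, b)


@jax.custom_vjp
def solve(A, b):
    return callback_solve(A, b)


def solve_fwd(A, b):
    x = callback_solve(A, b)
    # Keep A and x for the backward pass.
    return x, (A, x)


def solve_bwd(residuals, x_bar):
    A, x = residuals
    # Adjoint system A^T lam = x_bar costs one more external solve.
    lam = callback_solve(A.T, x_bar)
    # Cotangents for (A, b): A_bar = -lam x^T, b_bar = lam.
    return -jnp.outer(lam, x), lam


solve.defvjp(solve_fwd, solve_bwd)
```

This covers first-order reverse mode only: a `custom_vjp` function cannot be differentiated in forward mode, and higher-order derivatives would need further rules. That bookkeeping is what the reply above says Lineax handles automatically.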