patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

GMRES and NormalCG sometimes handle singular matrices. #18

Open patrick-kidger opened 1 year ago

patrick-kidger commented 1 year ago

GMRES:

import lineax as lx
import jax.numpy as jnp

# Use floating-point entries so the solver works in a float dtype.
a = jnp.array([[1.0, 1.0], [0.0, 0.0]])
b = jnp.array([1.0, 0.0])
operator = lx.MatrixLinearOperator(a)
solver = lx.GMRES(rtol=1e-6, atol=1e-6)
sol = lx.linear_solve(operator, b, solver)
print(sol.value)  # [1. 0.]

Note also that although a @ sol.value == b holds, [1, 0] is not the pseudoinverse (minimum-norm) solution, which for this system is [0.5, 0.5].
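As a reference point, the pseudoinverse solution for this system can be computed directly. The sketch below uses NumPy's np.linalg.pinv rather than lineax; that choice is an assumption made only so the check is independent of the solver under discussion. The name x_gmres is illustrative and simply hard-codes the value reported in the snippet above.

```python
import numpy as np

a = np.array([[1.0, 1.0], [0.0, 0.0]])
b = np.array([1.0, 0.0])

# Minimum-norm least-squares solution via the Moore-Penrose pseudoinverse.
x_pinv = np.linalg.pinv(a) @ b

# Value reported by GMRES in the issue (hard-coded here, not recomputed).
x_gmres = np.array([1.0, 0.0])

print("pseudoinverse solution:", x_pinv)
print("GMRES value satisfies a @ x == b:", np.allclose(a @ x_gmres, b))
print("norms (pinv, GMRES):", np.linalg.norm(x_pinv), np.linalg.norm(x_gmres))
```

Both vectors satisfy the system, but the pseudoinverse solution has the smaller Euclidean norm, which is what distinguishes it from the value GMRES returned.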

This example is taken from https://users.wpi.edu/~walker/Papers/gmres-singular,SIMAX_18,1997,37-51.pdf ("GMRES on (nearly) singular systems").

This is particularly troublesome for autodiff: if a singular matrix is used, we may get incorrect gradients.
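One way to see why derivatives are delicate at a singular matrix is that the solution map is not continuous there. The sketch below is only an illustration under stated assumptions: it uses NumPy rather than lineax, and the helper solve_perturbed and the perturbation eps are hypothetical names introduced here. It perturbs the bottom-right entry of the matrix from the GMRES example, which makes it nonsingular, and compares the unique solution with the pseudoinverse solution at eps = 0.

```python
import numpy as np

b = np.array([1.0, 0.0])


def solve_perturbed(eps):
    """Solve the (nonsingular) perturbed system [[1, 1], [0, eps]] x = b.

    Hypothetical helper for illustration only.
    """
    a_eps = np.array([[1.0, 1.0], [0.0, eps]])
    return np.linalg.solve(a_eps, b)


# Pseudoinverse solution of the singular system (eps = 0).
x_singular = np.linalg.pinv(np.array([[1.0, 1.0], [0.0, 0.0]])) @ b

for eps in (1e-2, 1e-4, 1e-6):
    print(eps, solve_perturbed(eps))
print("pinv at eps = 0:", x_singular)
```

For every nonzero eps the unique solution stays at [1, 0], while the pseudoinverse solution at eps = 0 is [0.5, 0.5], so there is a jump exactly at the singular matrix. This does not say what lineax computes for the gradient; it only shows why any gradient taken at that point deserves suspicion.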

Given the use of SVD as a subroutine within GMRES, we can probably detect this and do something smarter?

NormalCG:

import lineax as lx
import jax.numpy as jnp

matrix = jnp.array([[1.0, 0.0], [0.0, 0.0]])
vector = jnp.array([1., 2.])
operator = lx.MatrixLinearOperator(matrix)
solver = lx.NormalCG(rtol=1e-5, atol=1e-5)
out = lx.linear_solve(operator, vector, solver=solver)
print(out.value)
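For context, the system in the NormalCG example has no exact solution, because the second row of the matrix is zero while the second entry of the vector is not. The sketch below uses NumPy's np.linalg.lstsq (an assumption, chosen as a solver-independent reference; it is not what lineax does internally) to compute the minimum-norm least-squares solution and its residual.

```python
import numpy as np

matrix = np.array([[1.0, 0.0], [0.0, 0.0]])
vector = np.array([1.0, 2.0])

# Minimum-norm least-squares solution; also reports the numerical rank.
x_ls, _, rank, _ = np.linalg.lstsq(matrix, vector, rcond=None)

# A nonzero residual means the original system is inconsistent.
residual = matrix @ x_ls - vector

# The normal equations A^T A x = A^T b are still satisfied exactly.
normal_eq_residual = matrix.T @ residual

print("least-squares solution:", x_ls)
print("rank:", rank)
print("residual norm:", np.linalg.norm(residual))
print("normal-equations residual:", normal_eq_residual)
```

Since the normal equations are consistent even though the original system is not, a solver working on them can converge without any sign that the input matrix was singular.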