patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Differentiating w.r.t. initial guess throws an error #104

Open romanodev opened 3 months ago

romanodev commented 3 months ago

The gradient of the solution of an iteratively solved linear system w.r.t. the initial guess should be zero, since the converged solution does not depend on where the iteration starts. Instead, the following snippet


import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(
    jnp.array([[1.0, 0.0], [0.0, 1.0]]), tags=lx.positive_semidefinite_tag
)
b = jnp.array([1.0, 2.0])

def f(x0):
    return lx.linear_solve(operator, b, options={"y0": x0}, solver=lx.CG(atol=1e-12, rtol=1e-12)).value.sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))

gives the following error [lineax version 0.0.5]:


Traceback (most recent call last):
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 15, in <module>
    print(jax.grad(f)(x0))
          ^^^^^^^^^^^^^^^
  File "/home/romanodev/Project/JAX-BTE/test_lineax.py", line 11, in f
    return lx.linear_solve(operator, b, options={"y0": x0}, solver=lx.CG(atol=1e-12, rtol=1e-12)).value.sum()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unexpected tangent. `lineax.linear_solve(..., options=...)` cannot be autodifferentiated.

The problem is quickly resolved by wrapping the initial guess in jax.lax.stop_gradient:


import lineax as lx
from jax import numpy as jnp
import jax

operator = lx.MatrixLinearOperator(
    jnp.array([[1.0, 0.0], [0.0, 1.0]]), tags=lx.positive_semidefinite_tag
)
b = jnp.array([1.0, 2.0])

def f(x0):
    return lx.linear_solve(operator, b, options={"y0": jax.lax.stop_gradient(x0)}, solver=lx.CG(atol=1e-12, rtol=1e-12)).value.sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
[0. 0.]

For reference, JAX's built-in solver works fine (it differentiates the solve implicitly, so the initial guess receives an exact zero gradient):


from jax import numpy as jnp
import jax

A = jnp.array([[1.0, 0.0], [0.0, 1.0]])
b = jnp.array([1.0, 2.0])

def f(x0):
    return jax.scipy.sparse.linalg.cg(lambda x: A @ x, b, tol=1e-10, x0=x0)[0].sum()

x0 = jnp.zeros(2)

print(jax.grad(f)(x0))
[0. 0.]

Even though this is a corner case, the initial guess may well be a traced value inside a larger computational graph (as it was in my use case). It would also be great to be able to specify the initial guess for the backward pass.

patrick-kidger commented 3 months ago

I think this is working as intended. We don't support any notion of differentiating with respect to options, so a user should explicitly opt out of this, rather than silently getting unexpected gradients.
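
As a minimal sketch of such an explicit opt-out: stop gradients on every array leaf of the options pytree before passing it in. (Here nondiff_options is just a hypothetical helper name, not part of lineax's API.)

import jax

def nondiff_options(options):
    # Block gradient flow through every array leaf of `options`, so the
    # solve never sees a differentiable tangent for e.g. the initial guess.
    return jax.tree_util.tree_map(jax.lax.stop_gradient, options)

# Usage: lx.linear_solve(operator, b, options=nondiff_options({"y0": x0}), solver=...)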

On using an initial guess for the backward pass: indeed, right now we don't seem to support this. Probably the correct thing to do would be to use the transpose of the forward pass's initial guess, by filling in these two methods:

https://github.com/patrick-kidger/lineax/blob/4a7b1087cea9349d047da12420721490725de35e/lineax/_solver/cg.py#L236-L248
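
For concreteness, a rough sketch of what filling those in might look like. Purely for illustration, it assumes the CG solver's state is just the prepared operator (the real state may carry more) and that conjugation can be done leaf-wise; the substantive part is forwarding options["y0"]:

import jax
from jax import numpy as jnp

def transpose(self, state, options):
    # CG handles self-adjoint systems, so the transposed solve has the same
    # structure; pass the caller's initial guess through to it (the guess is
    # a vector, so its transpose is itself).
    operator = state  # illustrative assumption about what `init` returns
    transpose_state = operator.transpose()
    transpose_options = {"y0": options["y0"]} if "y0" in options else {}
    return transpose_state, transpose_options

def conj(self, state, options):
    # Illustrative only: conjugate every array leaf of the state (for a
    # matrix-backed state this conjugates the matrix) and conjugate the
    # initial guess to match.
    conj_state = jax.tree_util.tree_map(jnp.conj, state)
    conj_options = {"y0": jnp.conj(options["y0"])} if "y0" in options else {}
    return conj_state, conj_options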

I'd be happy to take a PR on this!

romanodev commented 3 months ago

I see. I will look into it. Great library, BTW!