patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Internal `sum_squares` returns unexpected result for composed linear operators #70

Closed: quattro closed this issue 11 months ago.

quattro commented 11 months ago

Hi 👋, as always, thanks for providing and maintaining such a wonderful library.

I was exploring the new internal functions, such as `sum_squares`, as a general way to compute `sum(X ** 2)`, where `X` is an `Array`, an `lx.MatrixLinearOperator`, or some composed linear operator (e.g., `lx.AddLinearOperator`).

For the simple cases of an `Array` or an `lx.MatrixLinearOperator` it works great, but I noticed unexpected results for composed operators. Here is an MWE:

```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx
import lineax.internal as lxi

key = rdm.PRNGKey(0)

# shape sizes
K = 4
N = 100

key, x_key = rdm.split(key)
X1 = rdm.normal(x_key, shape=(N, K))
A = X1.T @ X1
Aop = lx.MatrixLinearOperator(A)

# all good here
lxi.sum_squares(Aop)
lxi.sum_squares(A)
jnp.sum(A ** 2)

# expected 0, but returns lxi.sum_squares(Aop) + lxi.sum_squares(-Aop)
# likely due to PyTree shape, leaves...
Dop = Aop - Aop
lxi.sum_squares(Dop)

key, x_key = rdm.split(key)
X2 = rdm.normal(x_key, shape=(N, K))
B = X2.T @ X2
Bop = lx.MatrixLinearOperator(B)
Dop = Aop - Bop
jnp.sum((A - B) ** 2)   # 3490.8342
lxi.sum_squares(Dop)    # 82372.97
lxi.sum_squares(Aop) + lxi.sum_squares(-Bop)  # 82372.97
```

Given that this function is internal, this may not be a use case it was designed for, and that is totally fine. My primary motivation for computing the sum of squares of a general operator is that I keep a centered and shifted sparse matrix in memory as the linear operator `S = X - M @ B`. For now, I can unroll the sum of squares myself by pulling out the components and handling the cross term.
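Unrolling the sum of squares of `S = X - M @ B` by hand comes down to the Frobenius-norm identity ||X - MB||² = ||X||² - 2⟨X, MB⟩ + ||MB||². A minimal sketch in JAX, assuming `X` has shape `(n, p)`, `M` has shape `(n, k)` and `B` has shape `(k, p)`; the helper name `sum_squares_shifted` is hypothetical, and dense random arrays stand in for the sparse `X` from the use case above. The cross term is written as `sum((M.T @ X) * B)` and the last term as `sum((M.T @ M) * (B @ B.T))`, so the dense `(n, p)` matrix `M @ B` is never formed.

```python
import jax.numpy as jnp
import jax.random as rdm


def sum_squares_shifted(X, M, B):
    """Hypothetical helper: sum((X - M @ B) ** 2) without forming M @ B.

    Uses ||X - MB||_F^2 = ||X||_F^2 - 2 <X, MB>_F + ||MB||_F^2.
    """
    x_term = jnp.sum(X ** 2)
    # <X, M @ B>_F = sum((M^T X) * B); only needs M^T X, shape (k, p)
    cross_term = jnp.sum((M.T @ X) * B)
    # ||M @ B||_F^2 = tr(B^T M^T M B) = sum((M^T M) * (B B^T)), both (k, k)
    mb_term = jnp.sum((M.T @ M) * (B @ B.T))
    return x_term - 2.0 * cross_term + mb_term


key = rdm.PRNGKey(0)
kx, km, kb = rdm.split(key, 3)
n, p, k = 100, 4, 2
X = rdm.normal(kx, (n, p))
M = rdm.normal(km, (n, k))
B = rdm.normal(kb, (k, p))

expanded = sum_squares_shifted(X, M, B)
direct = jnp.sum((X - M @ B) ** 2)  # reference value, forms X - M @ B densely
```

Both quantities should agree up to floating-point error. With a sparse `X`, the only operation that touches `X` is the product `M.T @ X`, which is the part one would need to implement for the sparse format in use.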

patrick-kidger commented 11 months ago

Right! `sum_squares` just treats its input as a `PyTree[Array]`, i.e. a pytree with array-valued leaves. It's roughly equivalent to `sum(jnp.sum(xi**2) for xi in jax.tree_util.tree_leaves(x))`.
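A minimal sketch of that leafwise behaviour in plain JAX (the helper name `tree_sum_squares` is hypothetical and is not lineax's actual implementation). A pytree holding two matrices `A` and `B` stands in for a composed operator such as `Aop - Bop`, whose stored matrices are its leaves. Summing over the leaves gives `sum(A**2) + sum(B**2)` rather than `sum((A - B)**2)`, which matches what the MWE observes.

```python
import jax
import jax.numpy as jnp


def tree_sum_squares(x):
    # Sum of squares over every array leaf of a pytree, leaf by leaf.
    return sum(jnp.sum(xi ** 2) for xi in jax.tree_util.tree_leaves(x))


A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B = jnp.array([[0.5, 1.0], [1.5, 2.0]])

# Stand-in for a composed operator: the two matrices are stored as separate
# leaves, and the subtraction A - B is never applied to them.
composed_like = (A, B)

leafwise = tree_sum_squares(composed_like)  # sum(A**2) + sum(B**2) = 30 + 7.5
dense = tree_sum_squares(A - B)             # sum((A - B)**2) = 7.5
```

For the true sum of squares of a composed lineax operator, one option is to materialize it first with `op.as_matrix()` and sum over that single dense array. This is simple but builds the full matrix, so it is mainly practical for small operators.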