Hi 👋, as always, thanks for providing and maintaining such a wonderful library.
I was exploring the new internal functions, such as sum_squares, as a general way to compute sum(X ** 2), where X is either an Array, an lx.MatrixLinearOperator, or some composed linear operator (e.g., lx.AddLinearOperator).
For the simple cases of an Array or an lx.MatrixLinearOperator it works great, but I noticed unexpected results for composed operators. Here is an MWE:
```python
import jax.numpy as jnp
import jax.random as rdm
import lineax as lx
import lineax.internal as lxi

key = rdm.PRNGKey(0)

# shape sizes
K = 4
N = 100

key, x_key = rdm.split(key)
X1 = rdm.normal(x_key, shape=(N, K))
A = X1.T @ X1
Aop = lx.MatrixLinearOperator(A)

# all good here: the three values agree
print(lxi.sum_squares(Aop))
print(lxi.sum_squares(A))
print(jnp.sum(A ** 2))

# expected 0, but returns the same value as
# lxi.sum_squares(Aop) + lxi.sum_squares(-Aop),
# likely due to the PyTree structure/leaves of the composed operator
Dop = Aop - Aop
print(lxi.sum_squares(Dop))

key, x_key = rdm.split(key)
X2 = rdm.normal(x_key, shape=(N, K))
B = X2.T @ X2
Bop = lx.MatrixLinearOperator(B)

Dop = Aop - Bop
print(jnp.sum((A - B) ** 2))  # 3490.8342
print(lxi.sum_squares(Dop))  # 82372.97
print(lxi.sum_squares(Aop) + lxi.sum_squares(-Bop))  # 82372.97
```
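If the goal is the sum of squared entries of the matrix an operator represents (0 for `Aop - Aop`, `jnp.sum((A - B) ** 2)` for `Aop - Bop`), one way to get it regardless of how the operator is stored is to apply the operator to every standard basis vector. Below is a minimal JAX-only sketch; `operator_sum_squares` is a hypothetical helper, and the lambdas are stand-ins for the matrix-vector products of `Aop - Aop` and `Aop - Bop`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def operator_sum_squares(mv, in_size, dtype=jnp.float32):
    # Hypothetical helper: sum of squared entries of the matrix that `mv`
    # represents, computed by applying `mv` to every standard basis vector e_i.
    # The result depends only on what the operator computes, not on how its
    # parameters are stored.
    basis = jnp.eye(in_size, dtype=dtype)
    # Row i of `columns` is mv(e_i), i.e. column i of the represented matrix.
    columns = jax.vmap(mv)(basis)
    return jnp.sum(columns ** 2)


K, N = 4, 100
key_a, key_b = jr.split(jr.PRNGKey(0))
X1 = jr.normal(key_a, (N, K))
X2 = jr.normal(key_b, (N, K))
A = X1.T @ X1
B = X2.T @ X2

# Matrix-vector products of the composed operators, without forming A - B.
zero_op = operator_sum_squares(lambda v: A @ v - A @ v, K)
result = operator_sum_squares(lambda v: A @ v - B @ v, K)
reference = jnp.sum((A - B) ** 2)
print(zero_op, result, reference)
```

With lineax, one would presumably pass `Dop.mv` and `Dop.in_size()` instead of the lambdas, or, for small operators, materialise the matrix with `Dop.as_matrix()` and square-sum it directly. Either way this costs one matrix-vector product per input dimension, so it only suits operators with a small input size.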
Given that this function is internal, this may not be the use case it was originally designed for, and that is totally fine. My main motivation for computing the sum of squares of a general operator is that I keep a centered and shifted sparse matrix in memory as a linear operator S = X - M @ B. For now, I can unroll the sum of squares myself by pulling out the components and handling the cross term.
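The unrolling described above follows from expanding the Frobenius norm: ||X - M B||_F^2 = ||X||_F^2 - 2<X, M B> + ||M B||_F^2. Below is a minimal sketch, assuming X is an n x p `jax.experimental.sparse.BCOO` matrix without duplicate indices and M (n x k), B (k x p) are dense with small k. The helper name and the column-centring choice of M and B in the usage are hypothetical illustrations, not the original poster's code.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.experimental import sparse


def frobenius_sq_low_rank_shift(X, M, B):
    # Hypothetical helper: ||X - M @ B||_F^2 without materialising X - M @ B.
    # Assumes X is a 2-D BCOO matrix with no duplicate indices.
    # ||X||_F^2: only the stored nonzeros contribute.
    x_sq = jnp.sum(X.data ** 2)
    # Cross term <X, M @ B> = sum over nonzeros (i, j) of X_ij * (M[i] . B[:, j]).
    rows, cols = X.indices[:, 0], X.indices[:, 1]
    mb_at_nnz = jnp.sum(M[rows] * B[:, cols].T, axis=1)
    cross = jnp.sum(X.data * mb_at_nnz)
    # ||M @ B||_F^2 = trace(B^T M^T M B) = sum((M^T M) * (B B^T)), using k x k Gram matrices.
    mb_sq = jnp.sum((M.T @ M) * (B @ B.T))
    return x_sq - 2.0 * cross + mb_sq


n, p = 100, 4
key_x, key_mask = jr.split(jr.PRNGKey(0))
dense = jr.normal(key_x, (n, p)) * (jr.uniform(key_mask, (n, p)) < 0.2)
X = sparse.BCOO.fromdense(dense)

# Column centring as one (hypothetical) instance of S = X - M @ B.
M = jnp.ones((n, 1))
B = jnp.mean(dense, axis=0, keepdims=True)

result = frobenius_sq_low_rank_shift(X, M, B)
reference = jnp.sum((dense - M @ B) ** 2)
print(result, reference)
```

The cross term only touches the stored nonzeros of X, and the ||M B||_F^2 term only needs k x k Gram matrices, so nothing of size n x p is ever densified.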
Right! So sum_squares just treats its inputs as a PyTree[Array], i.e. a pytree with array-valued leaves. It's roughly equivalent to sum(jnp.sum(xi**2) for xi in jax.tree_util.tree_leaves(x)).
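The leaf-wise behaviour described in the reply can be reproduced with plain JAX. The sketch assumes that `-Aop` is stored as the original operator together with a scalar -1. That guess is consistent with the observation that sum_squares(Aop - Aop) equals sum_squares(Aop) + sum_squares(-Aop), but it has not been checked against lineax's source, so the tuple below is only a hypothetical stand-in for the leaves of `Aop - Aop`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def leafwise_sum_squares(tree):
    # Square and sum every array leaf, then add the per-leaf totals, as in
    # sum(jnp.sum(xi**2) for xi in jax.tree_util.tree_leaves(x)).
    # The pytree structure (and any arithmetic it encodes) is ignored.
    return sum(jnp.sum(jnp.square(leaf)) for leaf in jax.tree_util.tree_leaves(tree))


A = jr.normal(jr.PRNGKey(0), (4, 4))

# Hypothetical stand-in for the leaves of `Aop - Aop`: the matrix twice,
# plus the scalar -1 used for negation. Nothing here actually subtracts.
diff_like = (A, (A, -1.0))

leafwise = leafwise_sum_squares(diff_like)
dense = jnp.sum((A - A) ** 2)  # the "expected" answer, exactly 0
expected_leafwise = 2.0 * jnp.sum(A ** 2) + 1.0
print(leafwise, dense, expected_leafwise)
```

Under this assumption, the leaf-wise value is ||A||^2 + ||A||^2 + 1 rather than 0, which matches the pattern seen in the MWE.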