Closed quattro closed 1 year ago
Hmm, that is -- you want to solve an $Ax = b$ problem where both $A$ and $b$ are integers, and the construction of the problem is such that $x$ is known to have an integer value as well?
What kind of linear solver are you currently using to solve this?
Hi @patrick-kidger, I'm not suggesting that $b$ must be integer, only $A$. Without going into too much detail, this is related to space/memory: we can use a small integer representation (e.g., int8), since we only need to cover 3 values (which can be further improved by leveraging the sparsity of the 0s).
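For instance (a hypothetical illustration, not code from this thread — the sizes are made up), storing the three genotype values in int8 rather than float32 is already a 4x memory reduction, before exploiting sparsity at all:

```python
import numpy as np

# Hypothetical genotype matrix: n samples x p variants, values in {0, 1, 2}.
rng = np.random.default_rng(0)
genotypes_f32 = rng.integers(0, 3, size=(1000, 500)).astype(np.float32)
genotypes_i8 = genotypes_f32.astype(np.int8)

# int8 uses a quarter of the memory of float32.
print(genotypes_f32.nbytes)  # 2_000_000
print(genotypes_i8.nbytes)   # 500_000
```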
With some additional thought on the linear solver, I can see how this is limiting/nonsensical ($A^{-1}$ cannot be integer if $A$ is). But my use case is in the abstraction of matrices through linear operators, rather than needing the solver framework. Namely, representing a centered/scaled version of $A$ as $Z := (A - CB)S$, where $C$ and $B$ are the covariates/intercept and effects, and $S$ is a scaling term.
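For concreteness, here is a minimal sketch of that idea (shapes are assumptions, and $S$ is taken to be diagonal): $Zv$ can be computed factor by factor, so $Z$ is never materialised and $A$ stays in its small-integer representation:

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, k = 6, 4, 2
A = rng.integers(0, 3, size=(n, p)).astype(np.int8)  # genotype matrix, values in {0, 1, 2}
C = rng.normal(size=(n, k))                          # covariates/intercept
B = rng.normal(size=(k, p))                          # effects
s = rng.uniform(0.5, 2.0, size=p)                    # diagonal of the scaling term S

def z_mv(v):
    # Z v = (A - C B) S v, applied factor by factor so Z is never formed.
    sv = s * v
    return A @ sv - C @ (B @ sv)

# Dense reference, built only to check the matrix-free version.
Z = (A - C @ B) @ np.diag(s)
v = rng.normal(size=p)
assert np.allclose(z_mv(v), Z @ v)
```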
Ah, you're primarily just interested in representing the linear operator, not using it to solve a linear system? You could do that like this, then:
```python
import equinox as eqx
import jax
import jax.experimental.sparse as js
import jax.numpy as jnp
import lineax as lx


class SparseQuantisedOperator(lx.AbstractLinearOperator):
    matrix: js.BCOO

    def mv(self, vector):
        return js.sparsify(lambda m, v: m @ v)(self.matrix, vector)

    def as_matrix(self):
        raise ValueError("Refusing to materialise sparse matrix.")
        # Or you could do:
        # return self.matrix.todense()

    def transpose(self):
        return SparseQuantisedOperator(self.matrix.T)

    def in_structure(self):
        _, in_size = self.matrix.shape
        return jax.ShapeDtypeStruct((in_size,), jnp.float32)

    def out_structure(self):
        out_size, _ = self.matrix.shape
        return jax.ShapeDtypeStruct((out_size,), jnp.float32)


#
# Example usage: solve a linear system.
#

@lx.is_symmetric.register
def _(op: SparseQuantisedOperator):
    return False


@lx.is_negative_semidefinite.register
def _(op: SparseQuantisedOperator):
    return False


@lx.linearise.register
def _(op: SparseQuantisedOperator):
    return op


matrix = js.BCOO.fromdense(jnp.array([[1.0, 0.0], [0.0, 1.0]]))
op = SparseQuantisedOperator(matrix)
sol = lx.linear_solve(op, jnp.array([1., 2.]), solver=lx.NormalCG(rtol=1e-5, atol=1e-5))
print(sol.value)
```
Thanks @patrick-kidger, this has been working nicely for our internal testing. I have a remaining engineering question, however. Would it be more sensible to engineer this class as a generic, taking the floating-point type for its in/out structure as a type parameter:
```python
class SparseQuantisedOperator(lx.AbstractLinearOperator, Generic[T]):
    matrix: js.BCOO
    ...

sp_op = SparseQuantisedOperator[jnp.float32](...)
```
Or is it more acceptable to take the dtype explicitly as input during class construction:
```python
class SparseQuantisedOperator(lx.AbstractLinearOperator):
    matrix: js.BCOO
    dtype
    ...

sp_op = SparseQuantisedOperator(..., jnp.float32)
```
The reason I ask is that float32/float64 differences can crop up depending on whether JAX's x64 mode is enabled by default.
I think probably the generic approach. Otherwise you may have that `operator.matrix.dtype != operator.dtype`; it's generally better to make illegal data structures unrepresentable.
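To make that "unrepresentable" point concrete (a plain-numpy/dataclass sketch of the design principle, not lineax API — the class name is hypothetical): one way to guarantee consistency is to derive the dtype from the stored matrix rather than store it as a second field, so the two can never disagree by construction:

```python
import numpy as np
from dataclasses import dataclass


@dataclass(frozen=True)
class DerivedDtypeOperator:  # hypothetical name, for illustration only
    matrix: np.ndarray

    @property
    def dtype(self):
        # Derived, never stored: it cannot disagree with matrix.dtype.
        return self.matrix.dtype


op = DerivedDtypeOperator(np.zeros((2, 3), dtype=np.float32))
assert op.dtype == np.float32
```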
Hi all, I'm super excited to start using the new lineax library to re-implement some of my own linear operator functionality in statistical genetics. We commonly use integer-valued matrices to represent genotype data (0/1/2) and can represent a centered/scaled genotype matrix as a linear operator (I'll skip details for now). Currently, lineax enforces float-valued operators, as exemplified by
Is there room for either an Int operator, or for generic types that generalize the matrix operator over the underlying numeric type?