patrick-kidger / lineax

Linear solvers in JAX and Equinox.
Apache License 2.0

Integer-valued operators #27

Closed. quattro closed this issue 1 year ago.

quattro commented 1 year ago

Hi all, I'm super excited to start using the new lineax library to re-implement some of my own linear-operator functionality for statistical genetics. We commonly use integer-valued matrices to represent genotype data (0/1/2), and a centered/scaled genotype matrix can be represented as a linear operator (I'll skip the details for now). Currently, lineax forces operators to be float-valued, as exemplified by:

```python
if not jnp.issubdtype(matrix, jnp.inexact):
    matrix = matrix.astype(jnp.float32)
```

Is there room for either an integer operator, or for using generic types to make the matrix operator generic over its underlying numeric type?
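To put numbers on the memory argument, here is a small JAX sketch. The matrix size is made up. int8 storage needs 1 byte per entry, while the float32 copy produced by the cast above needs 4. A product with a float vector can still be taken directly, because JAX promotes int8 against float32 to float32.

```python
import jax.numpy as jnp

# Hypothetical genotype matrix: 1,000 samples x 500 variants, entries in {0, 1, 2}.
genotypes = (jnp.arange(1_000 * 500) % 3).reshape(1_000, 500).astype(jnp.int8)

# What the cast above produces.
as_float = genotypes.astype(jnp.float32)

int8_bytes = genotypes.size * genotypes.dtype.itemsize   # 500_000
float_bytes = as_float.size * as_float.dtype.itemsize    # 2_000_000

# A matvec can keep the int8 storage: jnp.dot promotes int8 x float32 to float32.
v = jnp.ones(500, dtype=jnp.float32)
out = jnp.dot(genotypes, v)
```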

patrick-kidger commented 1 year ago

Hmm, so you want to solve an $Ax = b$ problem where both $A$ and $b$ are integer-valued, and the problem is constructed so that $x$ is known to be integer-valued as well?
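The sketch below shows when the premise of this question can hold. The matrices are chosen only for illustration. An integer $A$ with determinant $\pm 1$ (unimodular) has an integer inverse, so any integer $b$ gives an integer $x$. Other integer matrices generally do not have this property. The solves are done in floating point, so the results are only approximately integral.

```python
import jax.numpy as jnp

b = jnp.array([1.0, 1.0])

# Unimodular: det = 2*1 - 1*1 = 1, so A^{-1} is an integer matrix.
A_unimodular = jnp.array([[2.0, 1.0], [1.0, 1.0]])
x_int = jnp.linalg.solve(A_unimodular, b)   # approximately [0., 1.]

# det = 2, so an integer b can give a non-integer x.
A_general = jnp.array([[2.0, 0.0], [0.0, 1.0]])
x_frac = jnp.linalg.solve(A_general, b)     # approximately [0.5, 1.]
```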

What kind of linear solver are you currently using to solve this?

quattro commented 1 year ago

Hi @patrick-kidger, I'm not suggesting that $b$ must be integer-valued, only $A$. Without going into too much detail, this is about space/memory. We can use a small integer representation (e.g., int8), since we only need to cover 3 values, and this can be improved further by exploiting the sparsity of the 0s.
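Here is a rough storage sketch for the sparsity point. It assumes `jax.experimental.sparse.BCOO` with its default int32 indices. For a 2-D matrix, each stored nonzero costs its int8 value plus two int32 indices, which is 9 bytes. BCOO therefore only beats a dense int8 array when fewer than roughly 1 in 9 entries are nonzero. The 5% density below is made up.

```python
import jax.numpy as jnp
from jax.experimental import sparse as js

# Hypothetical 100 x 100 genotype block with every 20th entry nonzero (5% density).
dense = jnp.where(jnp.arange(100 * 100).reshape(100, 100) % 20 == 0, 2, 0).astype(jnp.int8)

sp = js.BCOO.fromdense(dense)  # default index dtype is int32

dense_bytes = dense.size * dense.dtype.itemsize                    # 10_000
sparse_bytes = (sp.data.size * sp.data.dtype.itemsize              # 500 int8 values
                + sp.indices.size * sp.indices.dtype.itemsize)     # 500 x 2 int32 indices
```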

With some additional thought on the linear solver, I can see how this is limiting/nonsensical: in general $A^{-1}$ is not integer-valued even when $A$ is. But my use case is the abstraction of matrices through linear operators, rather than the solver framework. Namely, I want to represent a centered/scaled version of $A$ as $Z := (A - CB)S$, where $C$ and $B$ are the covariates/intercept and their effects, and $S$ is a scaling term.
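The following is a minimal sketch of applying $Z = (A - CB)S$ to a vector without materialising $Z$. It assumes $A$ is an $n \times p$ int8 genotype matrix, $C$ is $n \times k$, $B$ is $k \times p$, and $S$ is diagonal and stored as a vector `s`. These shapes, the helper names `z_mv`/`z_rmv`, and the example values are all hypothetical. The sketch relies on two identities: $Zv = A(Sv) - C(B(Sv))$ and $Z^\top u = S(A^\top u - B^\top (C^\top u))$. With these, only the thin factors $C$ and $B$ are ever multiplied together with vectors.

```python
import jax.numpy as jnp

def z_mv(A, C, B, s, v):
    """Compute Z @ v with Z = (A - C @ B) @ diag(s), never forming Z."""
    sv = s * v                        # S v, since S is diagonal
    return A @ sv - C @ (B @ sv)      # int8 A is promoted to the float dtype of sv

def z_rmv(A, C, B, s, u):
    """Compute Z.T @ u = diag(s) @ (A.T @ u - B.T @ (C.T @ u))."""
    return s * (A.T @ u - B.T @ (C.T @ u))

# Hypothetical small example: 6 samples, 4 variants, intercept + one covariate.
A = (jnp.arange(24) % 3).reshape(6, 4).astype(jnp.int8)
C = jnp.stack([jnp.ones(6), jnp.linspace(-1.0, 1.0, 6)], axis=1)
B = jnp.array([[1.0, 1.0, 1.0, 1.0], [0.5, -0.5, 0.0, 0.2]])
s = jnp.array([1.0, 0.5, 2.0, 0.25])
v = jnp.array([1.0, -1.0, 0.5, 2.0])

# Dense reference for checking only: multiplying by s scales the columns, i.e. (A - CB) diag(s).
Z = (A - C @ B) * s
```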

patrick-kidger commented 1 year ago

Ah, you're primarily just interested in representing the linear operator, not using it to solve a linear system? You could do that like this, then:

```python
import equinox as eqx
import jax
import jax.experimental.sparse as js
import jax.numpy as jnp
import lineax as lx


class SparseQuantisedOperator(lx.AbstractLinearOperator):
    matrix: js.BCOO

    def mv(self, vector):
        # Sparse matrix-vector product, traced through `sparsify`.
        return js.sparsify(lambda m, v: m @ v)(self.matrix, vector)

    def as_matrix(self):
        raise ValueError("Refusing to materialise sparse matrix.")
        # Or you could do:
        # return self.matrix.todense()

    def transpose(self):
        return SparseQuantisedOperator(self.matrix.T)

    def in_structure(self):
        _, in_size = self.matrix.shape
        return jax.ShapeDtypeStruct((in_size,), jnp.float32)

    def out_structure(self):
        out_size, _ = self.matrix.shape
        return jax.ShapeDtypeStruct((out_size,), jnp.float32)


# Example usage: solve a linear system.
# lineax answers these operator queries by dispatching on the operator type,
# so a custom operator registers its own implementations.

@lx.is_symmetric.register
def _(op: SparseQuantisedOperator):
    return False


@lx.is_negative_semidefinite.register
def _(op: SparseQuantisedOperator):
    return False


@lx.linearise.register
def _(op: SparseQuantisedOperator):
    return op  # already linear, nothing to do


matrix = js.BCOO.fromdense(jnp.array([[1.0, 0.0], [0.0, 1.0]]))
op = SparseQuantisedOperator(matrix)
sol = lx.linear_solve(op, jnp.array([1.0, 2.0]), solver=lx.NormalCG(rtol=1e-5, atol=1e-5))
```

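For int8 genotypes, one option is to cast only the stored nonzero values to the vector's dtype inside `mv`, rebuilding the `BCOO` from its `data` and `indices`. This is an assumption on my part and is not shown in the snippet above. Below is a standalone sketch using the hypothetical helper name `quantised_mv`. It relies only on the `BCOO((data, indices), shape=...)` constructor and the sparse-dense `@` product.

```python
import jax.numpy as jnp
from jax.experimental import sparse as js

def quantised_mv(matrix: js.BCOO, vector: jnp.ndarray) -> jnp.ndarray:
    # Cast only the stored nonzeros (e.g. int8 genotypes) to the vector's dtype;
    # the int32 indices are reused unchanged.
    as_float = js.BCOO((matrix.data.astype(vector.dtype), matrix.indices), shape=matrix.shape)
    return as_float @ vector

genotypes = jnp.array([[0, 1, 2], [2, 0, 0]], dtype=jnp.int8)
matrix = js.BCOO.fromdense(genotypes)
out = quantised_mv(matrix, jnp.array([1.0, 2.0, 3.0]))  # [8., 2.]
```

The stored matrix stays int8, so memory use is unchanged. Only the temporary values array is promoted for each product.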
quattro commented 1 year ago

Thanks @patrick-kidger, this has been working nicely in our internal testing. I have one remaining engineering question, however. Would it be more sensible to design this class as a generic that takes the floating-point type of its in/out structures as a type parameter:

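```python
from typing import Generic, TypeVar
T = TypeVar("T")  # floating-point dtype for the in/out structures

class SparseQuantisedOperator(lx.AbstractLinearOperator, Generic[T]):
    matrix: js.BCOO

sp_op = SparseQuantisedOperator[jnp.float32](...)
```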

Or is it preferable to take the dtype explicitly as a constructor argument:

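```python
class SparseQuantisedOperator(lx.AbstractLinearOperator):
    matrix: js.BCOO
    dtype: jnp.dtype = eqx.field(static=True)  # floating-point dtype for in/out structures

sp_op = SparseQuantisedOperator(..., jnp.float32)
```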

The reason I ask is that float32/float64 mismatches can crop up depending on whether x64 mode is enabled.
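The following minimal check illustrates the dependence being described, assuming no other configuration. JAX's default floating dtype follows the `jax_enable_x64` flag. As a result, a structure hard-coded to `jnp.float32` (as in `in_structure()` above) can disagree with vectors built from Python floats once x64 mode is turned on.

```python
import jax
import jax.numpy as jnp

x64 = jax.config.jax_enable_x64           # False unless x64 mode was enabled
default_float = jnp.asarray(1.0).dtype    # float32 by default, float64 under x64

# A structure hard-coded as in `in_structure()` above.
structure = jax.ShapeDtypeStruct((2,), jnp.float32)
vector = jnp.asarray([1.0, 2.0])
matches = structure.dtype == vector.dtype  # True by default, False under x64
```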

patrick-kidger commented 1 year ago

I'd probably go with the generic approach. Otherwise you can end up with `operator.matrix.dtype != operator.dtype`; it's generally better to make illegal data structures unrepresentable.
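Another way to apply the same principle avoids a separate type parameter altogether. This is a sketch of an alternative design, not what was proposed in the thread. The idea is to store only the matrix and compute the float dtype from it, so the operator's dtype can never disagree with its data. `QuantisedStructure` is a hypothetical plain-Python stand-in for the operator's structure methods. `jnp.promote_types(jnp.int8, jnp.float32)` is float32, so int8 genotypes get float32 structures.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class QuantisedStructure:
    matrix: jax.Array  # e.g. int8 genotypes, shape (out_size, in_size)

    @property
    def dtype(self):
        # Single source of truth: the float dtype the matrix promotes to.
        return jnp.promote_types(self.matrix.dtype, jnp.float32)

    def in_structure(self):
        _, in_size = self.matrix.shape
        return jax.ShapeDtypeStruct((in_size,), self.dtype)

    def out_structure(self):
        out_size, _ = self.matrix.shape
        return jax.ShapeDtypeStruct((out_size,), self.dtype)


structure = QuantisedStructure(jnp.zeros((3, 2), dtype=jnp.int8))
```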