wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

Compositional Linear Algebra (CoLA)


CoLA is a framework for scalable linear algebra that automatically exploits the structure often found in machine learning problems and beyond. CoLA natively supports PyTorch and JAX, with limited NumPy support when JAX is not installed.

Installation

pip install cola-ml

Features in CoLA

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.

Quick start guide

  1. **LinearOperators**. The core object in CoLA is the `LinearOperator`. You can add and subtract them (`+`, `-`), multiply them by constants (`*`, `/`), matrix-multiply them (`@`), and combine them in other ways, such as `kron`, `kronsum`, and `block_diag`.
    
```python
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)                         # 5x5 diagonal
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))  # 3x2 dense
C = B.T @ B                                                       # 2x2 product
D = C + 0.01 * cola.ops.I_like(C)                                 # add a scaled identity
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))       # 10x10 Kronecker product
F = cola.ops.BlockDiag(E, D)                                      # 12x12 block diagonal

v = jnp.ones(F.shape[-1])
print(F @ v)
```

```
[0.2 0.2 2.2 2.2 4.2 4.2 6.2 6.2 8.2 8.2 7.8 2.1 ]
```
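To make the composition concrete, here is a minimal sketch in plain NumPy (not CoLA's API) that materializes the same 12×12 matrix F densely, using `np.kron` and a hypothetical `dense_block_diag` helper, and reproduces the matrix-vector product above. Building F explicitly is exactly what CoLA's structured operators let you avoid; this is only a small-size reference check.

```python
import numpy as np

def dense_block_diag(*blocks):
    """Hypothetical helper: place each block on the diagonal of a zero matrix."""
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((rows, cols))
    r = c = 0
    for b in blocks:
        out[r:r + b.shape[0], c:c + b.shape[1]] = b
        r += b.shape[0]
        c += b.shape[1]
    return out

# Same ingredients as the CoLA example, but as ordinary dense arrays.
A = np.diag(np.arange(5) + .1)                   # 5x5 diagonal
B = np.array([[2., 1.], [-2., 1.1], [.01, .2]])  # 3x2 dense
C = B.T @ B                                      # 2x2
D = C + 0.01 * np.eye(2)                         # 2x2, shifted by 0.01 * I
E = np.kron(A, np.ones((2, 2)))                  # 10x10 Kronecker product
F = dense_block_diag(E, D)                       # 12x12 block diagonal

v = np.ones(F.shape[-1])
Fv = F @ v
# First ten entries: each row of E sums two copies of a diagonal entry of A.
# Last two entries: D @ ones, which is [7.8121, 2.062] in exact arithmetic.
print(np.round(Fv, 4))
```

The dense result agrees with the CoLA output above, whose last two entries appear rounded to one decimal.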


2. **Performing Linear Algebra**. With these objects, we can perform linear algebra operations even when the matrices they represent are very large.
```python
print(cola.linalg.trace(F))                     # trace of the 12x12 operator
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)         # regularized Gram operator
b = cola.linalg.inv(Q) @ v                      # solve Q b = v
print(jnp.linalg.norm(Q @ b - v))               # residual of the solve
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])  # first five returned eigenvalues
print(cola.linalg.sqrt(A))                      # matrix square root of the diagonal A
```

```
31.2701
0.0010193728
[ 2.0000000e-01+0.j  0.0000000e+00+0.j  2.1999998e+00+0.j
 -1.1920929e-07+0.j  4.1999998e+00+0.j]
diag([0.31622776 1.0488088  1.4491377  1.7606816  2.0248456 ])
```
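One reason operations on very large operators can stay cheap is that structure replaces dense storage. As an illustration, with row-major (NumPy-style) flattening, a Kronecker matrix-vector product satisfies (K1 ⊗ K2) vec(X) = vec(K1 X K2ᵀ), so it needs two small matrix products instead of the full Kronecker matrix. The NumPy sketch below checks this identity; `kron_matvec` is a hypothetical helper showing the idea, not code taken from CoLA.

```python
import numpy as np

def kron_matvec(K1, K2, x):
    """Hypothetical helper: compute np.kron(K1, K2) @ x without forming the Kronecker product.

    K1 has shape (p, q), K2 has shape (r, s), and x has length q * s.
    """
    q, s = K1.shape[1], K2.shape[1]
    X = x.reshape(q, s)                 # row-major un-flattening of x
    return (K1 @ X @ K2.T).reshape(-1)  # (p, r) result, flattened row-major

rng = np.random.default_rng(0)
K1 = rng.standard_normal((30, 20))      # p=30, q=20
K2 = rng.standard_normal((40, 50))      # r=40, s=50
x = rng.standard_normal(20 * 50)

fast = kron_matvec(K1, K2, x)           # about p*q*s + p*s*r multiply-adds
dense = np.kron(K1, K2) @ x             # forms a (p*r) x (q*s) = 1200 x 1000 matrix
print(np.allclose(fast, dense))
```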

For many of these functions, if we know additional information about a matrix, we can annotate it so that the algorithms run faster.

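```python
# Annotate Q as self-adjoint (Q = F.T @ F + 1e-3 * I is symmetric by construction)
Qs = cola.SelfAdjoint(Q)
# %timeit is IPython/Jupyter magic; compare the unannotated and annotated solves
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
```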
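Why an annotation can help: a matrix like Q = Fᵀ F + 10⁻³ I is symmetric and positive definite, and such systems can be solved with conjugate gradients (CG), which needs only matrix-vector products but does not apply to general matrices. Note that CG requires positive definiteness, not just symmetry; Q happens to have both. Which algorithm CoLA actually selects for a `SelfAdjoint` or `PSD` operator is not specified in this README, so treat the plain NumPy CG below as an illustration of the idea, not as CoLA's implementation.

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-10, max_iters=None):
    """Solve Q x = b for symmetric positive definite Q, given only u -> Q @ u."""
    n = b.shape[0]
    max_iters = n if max_iters is None else max_iters
    x = np.zeros_like(b)
    r = b - matvec(x)   # residual
    p = r.copy()        # search direction
    rs_old = r @ r
    for _ in range(max_iters):
        Qp = matvec(p)
        alpha = rs_old / (p @ Qp)
        x = x + alpha * p
        r = r - alpha * Qp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p  # new direction, Q-conjugate to the previous ones
        rs_old = rs_new
    return x

rng = np.random.default_rng(0)
M = rng.standard_normal((200, 50))
Q = M.T @ M + 1e-3 * np.eye(50)  # SPD, same form as Q = F.T @ F + 1e-3 * I above
b = np.ones(50)

x = conjugate_gradient(lambda u: Q @ u, b)  # only matrix-vector products with Q are used
print(np.linalg.norm(Q @ x - b))
```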
  3. **JAX and PyTorch**. We support both ML frameworks.
    
```python
# PyTorch backend
import torch
A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

# JAX backend
import jax.numpy as jnp
A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
```

```
tensor(25.)
25.0
```
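Both backends print 25 because trace(A ⊗ A) = trace(A)² = (1 + 4)² = 25. The identity trace(A ⊗ B) = trace(A)·trace(B) lets a structured operator compute the trace without forming the Kronecker product. The hypothetical helper below uses only `.diagonal().sum()`, which both NumPy arrays and PyTorch tensors provide; it is shown running on NumPy only.

```python
import numpy as np

def kron_trace(A, B):
    """Hypothetical helper: trace(kron(A, B)) = trace(A) * trace(B) for square A, B.

    Uses only .diagonal() and .sum(), so it should accept NumPy arrays or PyTorch tensors.
    """
    return A.diagonal().sum() * B.diagonal().sum()

A = np.array([[1., 2.], [3., 4.]])
fast = kron_trace(A, A)
dense = np.trace(np.kron(A, A))  # forms the 4x4 matrix, for comparison only
print(fast, dense)               # 25.0 25.0
```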


CoLA also works with autograd and JAX transformations such as `jit` and `vmap`:
```python
from jax import grad, jit, vmap

def myloss(x):
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)

# Differentiate with respect to x, vectorize over a batch of x values, and JIT-compile
g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
```

```
[-0.06611571 -0.12499995]
```
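As a sanity check on these numbers: for A = [[1, 2], [3, x]], det A = x - 6 and 1ᵀA⁻¹1 = (x - 4)/(x - 6), so the derivative is -2/(x - 6)². The NumPy sketch below compares that closed form with a central finite difference of the same loss at x = 0.5 and x = 10; it uses no autodiff and is only a check on the printed gradients.

```python
import numpy as np

def myloss_np(x):
    """NumPy version of myloss: ones @ inv(A) @ ones with A = [[1, 2], [3, x]]."""
    A = np.array([[1., 2.], [3., x]])
    return np.ones(2) @ np.linalg.solve(A, np.ones(2))  # avoids forming inv(A)

def closed_form_grad(x):
    # d/dx (x - 4) / (x - 6) = -2 / (x - 6)**2
    return -2.0 / (x - 6.0) ** 2

xs = np.array([.5, 10.])
h = 1e-6
fd = np.array([(myloss_np(x + h) - myloss_np(x - h)) / (2 * h) for x in xs])
exact = closed_form_grad(xs)
print(exact)  # approximately [-0.0661157, -0.125]
print(fd)
```

Both agree with the `jit(vmap(grad(myloss)))` output above to within float32 precision.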

Citing us

If you use CoLA, please cite the following paper:

Andres Potapczynski, Marc Finzi, Geoff Pleiss, and Andrew Gordon Wilson. "CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra." 2023.

```bibtex
@article{potapczynski2023cola,
  title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={arXiv preprint arXiv:2309.03060},
  year={2023}
}
```

Features implemented

| Category | Implemented |
|---|---|
| Linear Algebra | inverse, eig, diag, trace, logdet, exp, sqrt, f(A), SVD, pseudoinverse |
| LinearOperators | Diag, BlockDiag, Kronecker, KronSum, Sparse, Jacobian, Hessian, Fisher, Concatenated, Triangular, FFT, Tridiagonal |
| Annotations | SelfAdjoint, PSD, Unitary |
| Backends | PyTorch, Jax, Numpy (most operations) |
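For a self-adjoint matrix, the standard definition behind features such as f(A), sqrt, and exp is f(A) = V diag(f(λ)) Vᵀ, where A = V diag(λ) Vᵀ is its eigendecomposition. The dense NumPy sketch below applies that definition directly. It is a small-matrix reference only; this README does not describe how CoLA evaluates matrix functions of large operators, and the helper name `matrix_function` is hypothetical.

```python
import numpy as np

def matrix_function(A, f):
    """Hypothetical dense helper: f(A) for a real symmetric (self-adjoint) matrix A."""
    lam, V = np.linalg.eigh(A)  # A = V @ diag(lam) @ V.T
    return (V * f(lam)) @ V.T   # V @ diag(f(lam)) @ V.T

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 6))
A = M.T @ M + np.eye(6)         # symmetric positive definite

S = matrix_function(A, np.sqrt)
print(np.allclose(S @ S, A))    # True: S is a square root of A

# On a diagonal matrix this reduces to applying f to each diagonal entry,
# consistent with the sqrt(A) output in the quick start.
Dg = np.diag(np.arange(5) + .1)
print(np.diag(matrix_function(Dg, np.sqrt)))
```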

Contributing

See the contributing guidelines docs/CONTRIBUTING.md for information on submitting issues and pull requests.

CoLA is Apache 2.0 licensed.

Support and contact

Please raise an issue if you find a bug or slow performance when using CoLA.