CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond. CoLA natively supports PyTorch, Jax, as well as (limited) Numpy if Jax is not installed.
```shell
pip install cola-ml
```
CoLA provides large-scale linear algebra routines for `solve(A,b)`, `eig(A)`, `logdet(A)`, `exp(A)`, `trace(A)`, `diag(A)`, and `sqrt(A)`.

See https://cola.readthedocs.io/en/latest/ for our full documentation and many examples.
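1. **LinearOperators**. The core object in CoLA is the `LinearOperator`. You can add and subtract them (`+`, `-`), multiply them by constants (`*`, `/`), matrix multiply them (`@`), and combine them in other ways: `kron`, `kronsum`, `block_diag`, etc.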
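```python
import jax.numpy as jnp
import cola

A = cola.ops.Diagonal(jnp.arange(5) + .1)                           # 5x5 diagonal operator
B = cola.ops.Dense(jnp.array([[2., 1.], [-2., 1.1], [.01, .2]]))   # 3x2 dense operator
C = B.T @ B                                                        # 2x2 product operator
D = C + 0.01 * cola.ops.I_like(C)                                  # add a small multiple of the identity
E = cola.ops.Kronecker(A, cola.ops.Dense(jnp.ones((2, 2))))        # 10x10 Kronecker product
F = cola.ops.BlockDiag(E, D)                                       # 12x12 block-diagonal operator

v = jnp.ones(F.shape[-1])
print(F @ v)
```
```
[0.2 0.2 2.2 2.2 4.2 4.2 6.2 6.2 8.2 8.2 7.8 2.1 ]
```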
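As an illustration of the idea behind these composable objects (not of how CoLA implements `LinearOperator`), the NumPy sketch below uses a hypothetical `LazyOp` class that stores only a shape and a matrix-vector product. Sums, scalar multiples, products, Kronecker products, and block-diagonal stacks are built by composing those products, so no dense 12 × 12 matrix is ever formed. The helpers `dense`, `diagonal`, `identity_like`, `kronecker`, and `block_diag` are likewise hypothetical. Rebuilding `F` from the example this way reproduces the printed `F @ v`; before rounding, the last two entries are 7.8121 and 2.062.

```python
import numpy as np

class LazyOp:
    """Hypothetical matrix-free operator: a shape plus a matrix-vector product."""

    def __init__(self, shape, matvec):
        self.shape = shape
        self.matvec = matvec

    def __matmul__(self, other):
        if isinstance(other, LazyOp):  # operator @ operator: compose the matvecs
            return LazyOp((self.shape[0], other.shape[1]),
                          lambda x: self.matvec(other.matvec(x)))
        return self.matvec(other)      # operator @ vector: apply the matvec

    def __add__(self, other):          # operator + operator
        return LazyOp(self.shape, lambda x: self.matvec(x) + other.matvec(x))

    def __rmul__(self, c):             # scalar * operator
        return LazyOp(self.shape, lambda x: c * self.matvec(x))

def dense(M):
    return LazyOp(M.shape, lambda x: M @ x)

def diagonal(d):
    return LazyOp((len(d), len(d)), lambda x: d * x)

def identity_like(op):
    return LazyOp(op.shape, lambda x: x)

def kronecker(A, B):
    # (A kron B) x = vec(A X B^T), with X = x reshaped row-major to (cols(A), cols(B))
    (p, q), (r, s) = A.shape, B.shape
    def mv(x):
        X = x.reshape(q, s)
        XBt = np.stack([B.matvec(row) for row in X])               # X B^T, shape (q, r)
        AXBt = np.stack([A.matvec(col) for col in XBt.T], axis=1)  # A X B^T, shape (p, r)
        return AXBt.reshape(-1)
    return LazyOp((p * r, q * s), mv)

def block_diag(*ops):
    # Split x into one piece per block, apply each block, and concatenate the results
    splits = np.cumsum([op.shape[1] for op in ops])[:-1]
    def mv(x):
        return np.concatenate([op.matvec(piece) for op, piece in zip(ops, np.split(x, splits))])
    return LazyOp((sum(op.shape[0] for op in ops), sum(op.shape[1] for op in ops)), mv)

A = diagonal(np.arange(5) + .1)
Bm = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
C = dense(Bm.T) @ dense(Bm)          # this sketch has no transpose, so B^T is passed densely
D = C + 0.01 * identity_like(C)
E = kronecker(A, dense(np.ones((2, 2))))
F = block_diag(E, D)

v = np.ones(F.shape[-1])
print(F @ v)  # entries 0.2, 0.2, 2.2, ..., 8.2, 8.2, 7.8121, 2.062
```

The Kronecker matvec relies on the identity (A ⊗ B) vec(X) = vec(A X Bᵀ) for row-major vectorization, which replaces one large product with two small ones. CoLA goes further by choosing algorithms from an operator's structure, which this sketch does not attempt.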
2. **Performing Linear Algebra**. With these objects we can perform linear algebra operations even when they are very big.
```python
print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])
print(cola.linalg.sqrt(A))
```
```
31.2701
0.0010193728
[ 2.0000000e-01+0.j 0.0000000e+00+0.j 2.1999998e+00+0.j
-1.1920929e-07+0.j 4.1999998e+00+0.j]
diag([0.31622776 1.0488088 1.4491377 1.7606816 2.0248456 ])
```
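The printed trace and eigenvalues follow from two structural facts that let such computations avoid the dense matrix: the eigenvalues of A ⊗ J are the pairwise products of the eigenvalues of A and J, and the spectrum of a block-diagonal matrix is the union of its blocks' spectra. Since the trace is the sum of the eigenvalues, trace(F) follows as well. A NumPy-only check on the same matrices (no CoLA involved):

```python
import numpy as np

d = np.arange(5) + 0.1                       # diagonal of A
J = np.ones((2, 2))                          # eigenvalues 2 and 0
Bm = np.array([[2., 1.], [-2., 1.1], [.01, .2]])
D = Bm.T @ Bm + 0.01 * np.eye(2)

# Eigenvalues of kron(A, J) are the products d_i * mu_j with mu in {2, 0}
eig_E = np.outer(d, [2.0, 0.0]).ravel()      # [0.2, 0, 2.2, 0, 4.2, ...]
# The spectrum of a block-diagonal matrix is the union of its blocks' spectra
eig_F = np.concatenate([eig_E, np.linalg.eigvalsh(D)])
print(eig_F[:5])                             # first five values above, up to float32 rounding
print(eig_F.sum())                           # trace(F) = sum of eigenvalues, about 31.2701

# Dense cross-check on the full 12 x 12 matrix
F_dense = np.block([[np.kron(np.diag(d), J), np.zeros((10, 2))],
                    [np.zeros((2, 10)), D]])
print(np.allclose(np.sort(eig_F), np.linalg.eigvalsh(F_dense)))  # True
```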
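For many of these functions, if we know additional information about the matrices, we can annotate them to enable the algorithms to run faster.

```python
Qs = cola.SelfAdjoint(Q)  # mark Q as self-adjoint (Q = F^T F + 1e-3 I is symmetric)
# %timeit is an IPython/Jupyter magic; outside a notebook, use the timeit module instead
%timeit cola.linalg.inv(Q) @ v
%timeit cola.linalg.inv(Qs) @ v
```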
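One well-known reason such annotations can pay off: a solver that may assume symmetry and positive definiteness can use conjugate gradients (CG), which needs only matrix-vector products and short recurrences. The sketch below is a textbook CG in NumPy, offered as an illustration of that idea; which algorithm CoLA actually selects for an annotated operator is decided by its own dispatch rules and is not shown here. The random matrix `Q` is a stand-in: like the README's `Q = F.T @ F + 1e-3 * I`, it is symmetric positive definite, which CG requires (symmetry alone is not enough).

```python
import numpy as np

def conjugate_gradient(matvec, b, tol=1e-10, max_iters=None):
    """Textbook CG: solves Q x = b for symmetric positive definite Q using only matvecs."""
    x = np.zeros_like(b)
    r = b.copy()                 # residual b - Q x for the initial guess x = 0
    p = r.copy()                 # first search direction
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iters or 10 * len(b)):
        if np.sqrt(rs) <= tol * b_norm:   # converged (also handles b = 0)
            break
        Qp = matvec(p)
        alpha = rs / (p @ Qp)             # exact line search along p
        x = x + alpha * p
        r = r - alpha * Qp
        rs_new = r @ r
        p = r + (rs_new / rs) * p         # next Q-conjugate direction
        rs = rs_new
    return x

rng = np.random.default_rng(0)
M = rng.standard_normal((50, 50))
Q = M.T @ M + np.eye(50)         # symmetric positive definite stand-in
v = np.ones(50)
x = conjugate_gradient(lambda u: Q @ u, v)
print(np.linalg.norm(Q @ x - v))  # small residual
```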
The same code runs on either backend, PyTorch or JAX:

```python
import cola
import torch

A = cola.ops.Dense(torch.Tensor([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))

import jax.numpy as jnp

A = cola.ops.Dense(jnp.array([[1., 2.], [3., 4.]]))
print(cola.linalg.trace(cola.kron(A, A)))
```
```
tensor(25.)
25.0
```
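Both backends print 25 because trace(A ⊗ B) = trace(A) · trace(B), so here trace(A ⊗ A) = (1 + 4)² = 25. A minimal backend-agnostic sketch: the hypothetical helper `kron_trace` takes the array module as a parameter (NumPy below; `jax.numpy.trace` and `torch.trace` would be the corresponding calls for the other backends) and never materializes the 4×4 Kronecker product.

```python
import numpy as np

def kron_trace(A, B, xp=np):
    """trace(A kron B) = trace(A) * trace(B); the Kronecker product is never formed."""
    return xp.trace(A) * xp.trace(B)

A = np.array([[1., 2.], [3., 4.]])
print(kron_trace(A, A))          # 25.0
print(np.trace(np.kron(A, A)))   # 25.0, from the dense 4x4 matrix
```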
CoLA also supports autograd (and `jit`):

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
import cola

def myloss(x):
    A = cola.ops.Dense(jnp.array([[1., 2.], [3., x]]))
    return jnp.ones(2) @ cola.linalg.inv(A) @ jnp.ones(2)

g = jit(vmap(grad(myloss)))(jnp.array([.5, 10.]))
print(g)
```
```
[-0.06611571 -0.12499995]
```
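The printed gradients can be checked by hand. For ℓ(x) = 1ᵀ A(x)⁻¹ 1, the identity d(A⁻¹) = −A⁻¹ (dA) A⁻¹ gives ℓ'(x) = −(1ᵀ A⁻¹ e₂)(e₂ᵀ A⁻¹ 1), because only the (2, 2) entry of A depends on x. This evaluates to −8/121 ≈ −0.0661 at x = 0.5 and −1/8 at x = 10, matching the output above up to float32 rounding. A NumPy sketch comparing the closed form with a central finite difference:

```python
import numpy as np

def loss(x):
    A = np.array([[1., 2.], [3., x]])
    return np.ones(2) @ np.linalg.solve(A, np.ones(2))   # 1^T A^{-1} 1

def grad_closed_form(x):
    A = np.array([[1., 2.], [3., x]])
    w = np.linalg.solve(A, np.ones(2))     # A^{-1} 1
    u = np.linalg.solve(A.T, np.ones(2))   # (1^T A^{-1})^T
    return -u[1] * w[1]                    # dA/dx = e2 e2^T

h = 1e-6
for x in [0.5, 10.0]:
    finite_diff = (loss(x + h) - loss(x - h)) / (2 * h)
    print(x, grad_closed_form(x), finite_diff)
```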
If you use CoLA, please cite the following paper:
```bibtex
@article{potapczynski2023cola,
  title={{CoLA: Exploiting Compositional Structure for Automatic and Efficient Numerical Linear Algebra}},
  author={Andres Potapczynski and Marc Finzi and Geoff Pleiss and Andrew Gordon Wilson},
  journal={arXiv preprint arXiv:2309.03060},
  year={2023}
}
```
| Linear Algebra | inverse | eig | diag | trace | logdet | exp | sqrt | f(A) | SVD | pseudoinverse |
|---|---|---|---|---|---|---|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |

| LinearOperators | Diag | BlockDiag | Kronecker | KronSum | Sparse | Jacobian | Hessian | Fisher | Concatenated | Triangular | FFT | Tridiagonal |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |

| Annotations | SelfAdjoint | PSD | Unitary |
|---|---|---|---|
| Implementation | ✓ | ✓ | ✓ |

| Backends | PyTorch | Jax | Numpy |
|---|---|---|---|
| Implementation | ✓ | ✓ | Most operations |
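The `f(A)` entry refers to general matrix functions. For a self-adjoint matrix, one standard way to define f(A) is through the eigendecomposition A = V Λ Vᵀ, giving f(A) = V f(Λ) Vᵀ. The small dense NumPy sketch below illustrates that definition only; the helper `matrix_function` is hypothetical, large-scale methods such as Lanczos avoid the full decomposition, and CoLA's own algorithm choice is not shown here.

```python
import numpy as np

def matrix_function(A, f):
    """f(A) = V diag(f(lam)) V^T for symmetric A (dense, O(n^3))."""
    lam, V = np.linalg.eigh(A)
    return (V * f(lam)) @ V.T   # scale column j of V by f(lam_j), then multiply by V^T

rng = np.random.default_rng(0)
M = rng.standard_normal((4, 4))
A = M @ M.T + np.eye(4)         # symmetric positive definite test matrix

S = matrix_function(A, np.sqrt)
print(np.allclose(S @ S, A))                                    # True: a square root of A
print(np.allclose(matrix_function(A, lambda t: t**2), A @ A))   # True
print(np.isclose(np.trace(matrix_function(A, np.log)),
                 np.linalg.slogdet(A)[1]))                      # True: trace(log A) = logdet(A)
```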
See the contributing guidelines in [docs/CONTRIBUTING.md](docs/CONTRIBUTING.md) for information on submitting issues and pull requests.
CoLA is Apache 2.0 licensed.
Please raise an issue if you find a bug or slow performance when using CoLA.