JaxLinOp is a lightweight linear operator library written in JAX.
Consider solving the linear system $Ax = b$, where $A$ is a diagonal matrix and $b$ is a vector.
import jax.numpy as jnp
n = 1000
diag = jnp.linspace(1.0, 2.0, n)
A = jnp.diag(diag)
b = jnp.linspace(3.0, 4.0, n)
# A⁻¹ b
jnp.linalg.solve(A, b)
Doing so is costly for large problems: storing the dense matrix requires $O(n^2)$ memory, and solving the system costs $O(n^3)$ time, where $n$ is the dimension of the matrix.
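To put the memory figure in concrete terms, the following sketch (plain jax.numpy, no JaxLinOp involved) compares the bytes held by a dense `A` built as in the example above with those held by its diagonal alone. The absolute numbers depend on the dtype (JAX defaults to float32 unless 64-bit mode is enabled), but the ratio is always $n$.

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)
A = jnp.diag(diag)  # dense n x n matrix, almost entirely zeros

dense_bytes = A.nbytes    # stores n * n entries
diag_bytes = diag.nbytes  # stores only the n diagonal entries

print(f"dense: {dense_bytes} bytes, diagonal: {diag_bytes} bytes, "
      f"ratio: {dense_bytes // diag_bytes}")
```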
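But hold on a second. Notice that $A$ is diagonal: only its $n$ diagonal entries need to be stored, and $A^{-1}b$ is just $b$ divided elementwise by those entries, which costs $O(n)$.
JaxLinOp is designed to exploit structure of this kind.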
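As a quick check of that observation, the sketch below solves the same system twice using only jax.numpy: once with the general dense solver and once as an elementwise division. It illustrates the underlying arithmetic only and makes no assumption about how JaxLinOp implements `solve` internally.

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)
b = jnp.linspace(3.0, 4.0, n)

# Dense route: materialise the n x n matrix and call a general O(n^3) solver.
x_dense = jnp.linalg.solve(jnp.diag(diag), b)

# Structured route: for diagonal A, A^{-1} b is elementwise division, O(n).
x_diag = b / diag

print(jnp.allclose(x_dense, x_diag, rtol=1e-5))
```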
import jaxlinop
A = jaxlinop.DiagonalLinearOperator(diag=diag)
# A⁻¹ b
A.solve(b)
JaxLinOp is designed to exploit such structure automatically, reducing the cost of matrix addition, multiplication, log-determinant computation and more, for other matrix structures too!
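To illustrate why structure makes these operations cheap, here is a minimal sketch of a diagonal operator written in plain JAX. The class `Diag` and its method names (`solve`, `log_det`, `to_dense`) are hypothetical, chosen for this example, and are not JaxLinOp's API; the point is that each operation reduces to an $O(n)$ elementwise computation on the stored diagonal.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class Diag:
    """Hypothetical diagonal operator that stores only its n diagonal entries."""

    d: jnp.ndarray

    def __add__(self, other: "Diag") -> "Diag":
        # diag(a) + diag(b) = diag(a + b): O(n) instead of O(n^2).
        return Diag(self.d + other.d)

    def __matmul__(self, other: "Diag") -> "Diag":
        # diag(a) @ diag(b) = diag(a * b): O(n) instead of O(n^3).
        return Diag(self.d * other.d)

    def solve(self, rhs: jnp.ndarray) -> jnp.ndarray:
        # diag(a)^{-1} rhs scales row i of rhs by 1 / a_i: O(n) per right-hand side.
        return rhs / self.d if rhs.ndim == 1 else rhs / self.d[:, None]

    def log_det(self) -> jnp.ndarray:
        # log|det diag(a)| = sum_i log|a_i|: O(n) instead of O(n^3).
        return jnp.sum(jnp.log(jnp.abs(self.d)))

    def to_dense(self) -> jnp.ndarray:
        # Only needed for checking against dense results.
        return jnp.diag(self.d)
```

The design choice here is closure: adding or multiplying two diagonal operators returns another diagonal operator, so the structure survives a chain of operations instead of being lost to a dense intermediate.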
The flexible design of JaxLinOp will allow users to implement their own custom linear operators.
from jaxlinop import LinearOperator
class MyLinearOperator(LinearOperator):
    def __init__(self, *args, **kwargs):  # operator-specific constructor arguments
        ...
    # There will be a minimal number of methods that users need to implement for their custom operator.
    # For optimal efficiency, we'll make it easy for users to add optional methods to their operator
    # if they give better performance than the defaults.
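The pattern described in those comments (a few required methods, plus optional overrides that replace generic defaults) can be sketched with Python's `abc` module. Everything here is hypothetical: `SketchOperator`, `ScaledIdentity` and their methods are illustrative names, not JaxLinOp's `LinearOperator` interface, whose required methods are not specified here.

```python
import abc

import jax.numpy as jnp


class SketchOperator(abc.ABC):
    """Hypothetical base class: only `shape` and `matmul` are required."""

    @property
    @abc.abstractmethod
    def shape(self) -> tuple:
        ...

    @abc.abstractmethod
    def matmul(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return A @ x."""

    def to_dense(self) -> jnp.ndarray:
        # Generic default: apply A to every column of the identity matrix.
        return self.matmul(jnp.eye(self.shape[1]))

    def solve(self, rhs: jnp.ndarray) -> jnp.ndarray:
        # Generic default: densify, then use an O(n^3) dense solver.
        return jnp.linalg.solve(self.to_dense(), rhs)


class ScaledIdentity(SketchOperator):
    """Hypothetical custom operator: A = scale * I, of size n x n."""

    def __init__(self, scale: float, n: int):
        self.scale = scale
        self.n = n

    @property
    def shape(self) -> tuple:
        return (self.n, self.n)

    def matmul(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.scale * x

    # Optional override: exploits the structure, O(n) instead of O(n^3).
    def solve(self, rhs: jnp.ndarray) -> jnp.ndarray:
        return rhs / self.scale
```

Only `shape` and `matmul` are mandatory here; `ScaledIdentity` overrides `solve` because dividing by the scale beats the dense default, while `to_dense` falls back to the generic version inherited from the base class.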
The latest stable version of jaxlinop can be installed via pip:
pip install jaxlinop
Note
We recommend you check your installation version:
python -c 'import jaxlinop; print(jaxlinop.__version__)'
Warning
The development version is possibly unstable and may contain bugs.
Clone a copy of the repository to your local machine and run the setup configuration in development mode.
git clone https://github.com/JaxGaussianProcesses/JaxLinOp.git
cd JaxLinOp
python -m setup develop
Note
We advise that you create a virtual environment before installing:
conda create -n jaxlinop_ex python=3.10.0
conda activate jaxlinop_ex
and recommend you check your installation passes the supplied unit tests:
python -m pytest tests/