JaxGaussianProcesses / JaxLinOp

Linear Operators in JAX.
Apache License 2.0

This project has now been incorporated into GPJax.

JaxLinOp is a lightweight linear operator library written in jax.

Overview

Consider solving a linear system $Ax = b$, where $A$ is a diagonal matrix and $b$ is a vector.

import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)

A = jnp.diag(diag)
b = jnp.linspace(3.0, 4.0, n)

# A⁻¹ b
jnp.linalg.solve(A, b)

Doing so is costly for large problems. Storing the dense matrix requires $O(n^2)$ memory, and solving the system (or inverting the matrix) costs $O(n^3)$ time, where $n$ is the number of data points.

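But hold on a second. Notice that $A$ is zero everywhere off its diagonal, so we only need to store its $n$ diagonal entries ($O(n)$ memory), and solving against it takes only $O(n)$ operations.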
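The inverse of a diagonal matrix is diagonal with reciprocal entries, so $A^{-1} b$ is simply the elementwise quotient $b_i / d_i$, where $d_i$ are the diagonal entries (`diag` in the code). The plain-JAX sketch below does not use JaxLinOp at all; it redefines the same `n`, `diag` and `b` so it runs on its own, and checks the elementwise route against the dense solve.

```python
import jax.numpy as jnp

n = 1000
diag = jnp.linspace(1.0, 2.0, n)
b = jnp.linspace(3.0, 4.0, n)

# Dense route: materialise the n x n matrix (O(n^2) memory) and call a
# general-purpose solver (O(n^3) time).
A = jnp.diag(diag)
x_dense = jnp.linalg.solve(A, b)

# Structured route: keep only the n diagonal entries and divide
# elementwise (O(n) memory and time).
x_diag = b / diag

print(jnp.allclose(x_dense, x_diag))
```

Both routes give the same answer; only the structured one avoids building the $n \times n$ matrix.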

JaxLinOp is designed to exploit structure of this kind.

import jaxlinop

A = jaxlinop.DiagonalLinearOperator(diag=diag)

# A⁻¹ b
A.solve(b)

JaxLinOp is designed to deliver such cost savings automatically for matrix addition, multiplication, log-determinants and more, and for other matrix structures too!
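To illustrate what these savings look like for the diagonal case, here is a plain-JAX sketch. It is not JaxLinOp's implementation or API: the helpers `diag_add`, `diag_matmul` and `diag_log_det` are hypothetical names, and each diagonal matrix is represented only by its diagonal vector. Each shortcut is compared against the equivalent dense computation.

```python
import jax.numpy as jnp

# Hypothetical helpers for illustration only; they are not part of
# JaxLinOp's API. A diagonal matrix is represented by its diagonal vector.

def diag_add(d1, d2):
    # diag(d1) + diag(d2) = diag(d1 + d2): O(n) instead of O(n^2).
    return d1 + d2

def diag_matmul(d1, d2):
    # diag(d1) @ diag(d2) = diag(d1 * d2): O(n) instead of O(n^3).
    return d1 * d2

def diag_log_det(d):
    # log det diag(d) = sum_i log d_i, assuming every d_i > 0: O(n).
    return jnp.sum(jnp.log(d))

d1 = jnp.linspace(1.0, 2.0, 5)
d2 = jnp.linspace(3.0, 4.0, 5)

# Dense counterparts, used only to check the shortcuts.
A1, A2 = jnp.diag(d1), jnp.diag(d2)

sum_ok = bool(jnp.allclose(jnp.diag(diag_add(d1, d2)), A1 + A2))
prod_ok = bool(jnp.allclose(jnp.diag(diag_matmul(d1, d2)), A1 @ A2))

sign, logabsdet = jnp.linalg.slogdet(A1)
logdet_ok = bool(jnp.allclose(diag_log_det(d1), logabsdet)) and bool(sign == 1.0)

print(sum_ok, prod_ok, logdet_ok)
```

The log-determinant shortcut assumes strictly positive diagonal entries, which holds for the covariance-style matrices common in Gaussian process work.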

Custom Linear Operator (details to come soon)

The flexible design of JaxLinOp will allow users to implement their own custom linear operators.

from jaxlinop import LinearOperator

class MyLinearOperator(LinearOperator):

  def __init__(self, *args, **kwargs):  # placeholder: constructor arguments depend on your operator
    ...

# Users will need to implement only a minimal number of methods for their custom operator.
# For optimal efficiency, we'll make it easy to add optional methods to an operator
# whenever they outperform the defaults.
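Until those details land, the idea can be sketched in plain JAX. The class below, `ScaledIdentity`, is a hypothetical operator representing $cI_n$; it does not subclass or follow JaxLinOp's `LinearOperator` interface, and its method names (`matmul`, `solve`, `log_det`, `to_dense`) are assumptions chosen for illustration. It is registered with `jax.tree_util.register_pytree_node_class` so instances can be passed through `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ScaledIdentity:
    """Hypothetical operator for c * I_n, stored as (c, n) instead of n x n.

    Illustrative only: not JaxLinOp's LinearOperator interface.
    """

    def __init__(self, scale, n):
        self.scale = scale  # array leaf, traced under jit
        self.n = n  # static Python int, kept as pytree aux data

    @property
    def shape(self):
        return (self.n, self.n)

    def matmul(self, x):
        # (c I) x = c x: O(n) instead of a dense O(n^2) matvec.
        return self.scale * x

    def solve(self, b):
        # (c I)^{-1} b = b / c.
        return b / self.scale

    def log_det(self):
        # log det(c I_n) = n log c, assuming c > 0.
        return self.n * jnp.log(self.scale)

    def to_dense(self):
        return self.scale * jnp.eye(self.n)

    # Pytree hooks: `scale` is a child (leaf), `n` is static aux data.
    def tree_flatten(self):
        return (self.scale,), self.n

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


@jax.jit
def solve_with(op, b):
    return op.solve(b)


op = ScaledIdentity(jnp.asarray(2.0), 4)
b = jnp.arange(4.0)
x = solve_with(op, b)
print(x)
```

Keeping `n` as static aux data means the operator's shape is known at trace time, while `scale` stays a traced array that can be differentiated or varied without recompiling.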

Installation

Stable version

The latest stable version of jaxlinop can be installed via pip:

pip install jaxlinop

Note

We recommend you check your installation version:

python -c 'import jaxlinop; print(jaxlinop.__version__)'

Development version

Warning

This version is possibly unstable and may contain bugs.

Clone a copy of the repository to your local machine and run the setup configuration in development mode.

git clone https://github.com/JaxGaussianProcesses/JaxLinOp.git
cd JaxLinOp
python -m setup develop

Note

We advise creating a virtual environment before installing:

conda create -n jaxlinop_ex python=3.10.0
conda activate jaxlinop_ex

and recommend checking that your installation passes the supplied unit tests:

python -m pytest tests/