wesselb / lab

A generic interface for linear algebra backends
MIT License

LAB


A generic interface for linear algebra backends: code it once, run it on any backend

Requirements and Installation

LAB is published on PyPI under the name backends:

pip install backends

Basic Usage

The basic use case for the package is to write code that automatically determines the backend to use depending on the types of its arguments.

Example:

import lab as B
import lab.autograd    # Load the AutoGrad extension.
import lab.torch       # Load the PyTorch extension.
import lab.tensorflow  # Load the TensorFlow extension.
import lab.jax         # Load the JAX extension.

def objective(matrix):
    outer_product = B.matmul(matrix, matrix, tr_b=True)
    return B.mean(outer_product)

The AutoGrad, PyTorch, TensorFlow, and JAX extensions are not loaded automatically, so that LAB does not force a dependency on all four frameworks. Alternatively, an extension can be loaded via import lab.autograd as B.

Run it with NumPy and AutoGrad:

>>> import autograd.numpy as np
>>> from autograd import grad

>>> objective(B.randn(np.float64, 2, 2))
0.15772589216756833

>>> grad(objective)(B.randn(np.float64, 2, 2))
array([[ 0.23519042, -1.06282928],
       [ 0.23519042, -1.06282928]])

Run it with TensorFlow:

>>> import tensorflow as tf

>>> objective(B.randn(tf.float64, 2, 2))
<tf.Tensor 'Mean:0' shape=() dtype=float64>

Run it with PyTorch:

>>> import torch

>>> objective(B.randn(torch.float64, 2, 2))
tensor(1.9557, dtype=torch.float64)

Run it with JAX:

>>> import jax

>>> import jax.numpy as jnp

>>> jax.jit(objective)(B.randn(jnp.float32, 2, 2))
DeviceArray(0.3109299, dtype=float32)

>>> jax.jit(jax.grad(objective))(B.randn(jnp.float32, 2, 2))
DeviceArray([[ 0.2525182, -1.26065  ],
             [ 0.2525182, -1.26065  ]], dtype=float32)

List of Types

This section lists all available types, which can be used to check types of objects or extend functions.

General

Int          # Integers
Float        # Floating-point numbers
Complex      # Complex numbers
Bool         # Booleans
Number       # Numbers
Numeric      # Numerical objects, including booleans
DType        # Data type
Framework    # Anything accepted by supported frameworks
Device       # Any device type

NumPy

NPNumeric
NPDType
NPRandomState

NP  # Anything NumPy

AutoGrad

AGNumeric
AGDType
AGRandomState

AG  # Anything AutoGrad

TensorFlow

TFNumeric
TFDType
TFRandomState
TFDevice

TF  # Anything TensorFlow

PyTorch

TorchNumeric
TorchDType
TorchDevice
TorchRandomState

Torch  # Anything PyTorch

JAX

JAXNumeric
JAXDType
JAXDevice
JAXRandomState

JAX  # Anything JAX

List of Methods

This section lists all available constants and methods.

Positional arguments must be given positionally, and keyword arguments must be given as keyword arguments. For example, sum(tensor, axis=1) is valid, but sum(tensor, 1) is not.

See the documentation for more detailed descriptions of each function.

Special Variables

default_dtype          # Default data type.
epsilon                # Magnitude of diagonal to regularise matrices with.
cholesky_retry_factor  # If the Cholesky fails, retry with `epsilon` increased by at most this factor.

Constants

nan
pi
log_2_pi

Data Types

dtype(a)
dtype_float(dtype)
dtype_float(a)
dtype_int(dtype)
dtype_int(a)

promote_dtypes(*dtype)
issubdtype(dtype1, dtype2)

Generic

isabstract(a)
jit(f, **kw_args)

isnan(a)
real(a)
imag(a)

device(a)
on_device(device)
on_device(a)
set_global_device(device)
to_active_device(a)

zeros(dtype, *shape)
zeros(*shape)
zeros(ref)

ones(dtype, *shape)
ones(*shape)
ones(ref)

zero(dtype)
zero(*refs)

one(dtype)
one(*refs)

eye(dtype, *shape)
eye(*shape)
eye(ref)

linspace(dtype, a, b, num)
linspace(a, b, num)

range(dtype, start, stop, step)
range(dtype, stop)
range(dtype, start, stop)
range(start, stop, step)
range(start, stop)
range(stop)

cast(dtype, a)

identity(a)
round(a)
floor(a)
ceil(a)
negative(a)
abs(a)
sign(a)
sqrt(a)
exp(a)
log(a)
log1p(a)
sin(a)
arcsin(a)
cos(a)
arccos(a)
tan(a)
arctan(a)
tanh(a)
arctanh(a)
loggamma(a)
logbeta(a)
erf(a)
sigmoid(a)
softplus(a)
relu(a)

add(a, b)
subtract(a, b)
multiply(a, b)
divide(a, b)
power(a, b)
minimum(a, b)
maximum(a, b)
leaky_relu(a, alpha)

softmax(a, axis=None)

min(a, axis=None, squeeze=True)
max(a, axis=None, squeeze=True)
sum(a, axis=None, squeeze=True)
prod(a, axis=None, squeeze=True)
mean(a, axis=None, squeeze=True)
std(a, axis=None, squeeze=True)
logsumexp(a, axis=None, squeeze=True)
all(a, axis=None, squeeze=True)
any(a, axis=None, squeeze=True)

nansum(a, axis=None, squeeze=True)
nanprod(a, axis=None, squeeze=True)
nanmean(a, axis=None, squeeze=True)
nanstd(a, axis=None, squeeze=True)

argmin(a, axis=None)
argmax(a, axis=None)

lt(a, b)
le(a, b)
gt(a, b)
ge(a, b)
eq(a, b)
ne(a, b)

bvn_cdf(a, b, c)

cond(condition, f_true, f_false, *xs)
where(condition, a, b)
scan(f, xs, *init_state)

sort(a, axis=-1, descending=False)
argsort(a, axis=-1, descending=False)
quantile(a, q, axis=None)

to_numpy(a)
jit_to_numpy(a)  # Caches results for `B.jit`.

Linear Algebra

transpose(a, perm=None) (alias: t, T)
matmul(a, b, tr_a=False, tr_b=False) (alias: mm, dot)
einsum(equation, *elements)
trace(a, axis1=0, axis2=1)
kron(a, b)
svd(a, compute_uv=True)
eig(a, compute_eigvecs=True)
solve(a, b)
inv(a)
pinv(a)
det(a) 
logdet(a) 
expm(a)
logm(a)
cholesky(a) (alias: chol)

cholesky_solve(a, b)  (alias: cholsolve)
triangular_solve(a, b, lower_a=True) (alias: trisolve)
toeplitz_solve(a, b, c) (alias: toepsolve)
toeplitz_solve(a, c)

outer(a, b)
reg(a, diag=None, clip=True)

pw_dists2(a, b)
pw_dists2(a)
pw_dists(a, b)
pw_dists(a)

ew_dists2(a, b)
ew_dists2(a)
ew_dists(a, b)
ew_dists(a)

pw_sums2(a, b)
pw_sums2(a)
pw_sums(a, b)
pw_sums(a)

ew_sums2(a, b)
ew_sums2(a)
ew_sums(a, b)
ew_sums(a)

Random

set_random_seed(seed) 
create_random_state(dtype, seed=0)
global_random_state(dtype)
global_random_state(a)
set_global_random_state(state)

rand(state, dtype, *shape)
rand(dtype, *shape)
rand(*shape)
rand(state, ref)
rand(ref)

randn(state, dtype, *shape)
randn(dtype, *shape)
randn(*shape)
randn(state, ref)
randn(ref)

randcat(state, p, *shape)
randcat(p, *shape)

choice(state, a, *shape, p=None)
choice(a, *shape, p=None)

randint(state, dtype, *shape, lower=0, upper)
randint(dtype, *shape, lower=0, upper)
randint(*shape, lower=0, upper)
randint(state, ref, *, lower=0, upper)
randint(ref, *, lower=0, upper)

randperm(state, dtype, n)
randperm(dtype, n)
randperm(n)

randgamma(state, dtype, *shape, alpha, scale)
randgamma(dtype, *shape, alpha, scale)
randgamma(*shape, alpha, scale)
randgamma(state, ref, *, alpha, scale)
randgamma(ref, *, alpha, scale)

randbeta(state, dtype, *shape, alpha, beta)
randbeta(dtype, *shape, alpha, beta)
randbeta(*shape, alpha, beta)
randbeta(state, ref, *, alpha, beta)
randbeta(ref, *, alpha, beta)

Shaping

shape(a, *dims)
rank(a)
length(a) (alias: size)
is_scalar(a)
expand_dims(a, axis=0, times=1)
squeeze(a, axis=None)
uprank(a, rank=2)
downrank(a, rank=2, preserve=False)
broadcast_to(a, *shape)

diag(a)
diag_extract(a)
diag_construct(a)
flatten(a)
vec_to_tril(a, offset=0)
tril_to_vec(a, offset=0)
stack(*elements, axis=0)
unstack(a, axis=0, squeeze=True)
reshape(a, *shape)
concat(*elements, axis=0)
concat2d(*rows)
tile(a, *repeats)
take(a, indices_or_mask, axis=0)
submatrix(a, indices_or_mask)

Devices

You can get the device of a tensor with B.device(a), and you can execute a computation on a device by entering B.on_device(device) as a context:

import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf

with B.on_device("gpu:0"):
    a = B.randn(tf.float32, 2, 2)
    b = B.randn(tf.float32, 2, 2)
    c = a @ b

Within such a context, a tensor that is not on the active device can be moved to the active device with B.to_active_device(a).

You can also globally set the active device with B.set_global_device("gpu:0").

Lazy Shapes

If a function is evaluated abstractly, then elements of the shape of a tensor, e.g. B.shape(a)[0], may also be tensors, which can break dispatch. By entering B.lazy_shapes(), shapes and elements of shapes will be wrapped in a custom type to fix this issue.

import lab as B

with B.lazy_shapes():
    a = B.eye(2)
    print(type(B.shape(a)))
    # <class 'lab.shape.Shape'>
    print(type(B.shape(a)[0]))
    # <class 'lab.shape.Dimension'>

Random Numbers

If you call a random number generator without providing a random state, e.g. B.randn(np.float32, 2), the global random state from the corresponding backend is used. For JAX, since there is no global random state, LAB provides a JAX global random state accessible through B.jax_global_random_state once lab.jax is loaded.

If you do not want to use a global random state but rather explicitly maintain one, you can create a random state with B.create_random_state and then pass this as the first argument to the random number generators. The random number generators will then return a tuple containing the updated random state and the random result.

import lab as B
import lab.tensorflow  # Load the TensorFlow extension.
import tensorflow as tf

# Create random state.
state = B.create_random_state(tf.float32, seed=0)

# Generate two random arrays.
state, x = B.randn(state, tf.float32, 2)
state, y = B.randn(state, tf.float32, 2)

Control Flow Cache

Coming soon!