andres-fr / skerch

Sketched matrix decompositions for PyTorch
MIT License

skerch: Sketched matrix decompositions for PyTorch

skerch is a Python package to compute different decompositions (SVD, Hermitian Eigendecomposition, diagonal, subdiagonal, triangular, block-triangular) of linear operators via sketched methods.
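
For readers new to sketched methods, the following minimal sketch shows the basic idea behind a sketched SVD in plain PyTorch: probe the operator with random vectors, orthonormalize the responses, and take a small SVD of the projected operator. It is the well-known two-pass randomized range-finder, not skerch's implementation (skerch's outer/inner measurement scheme is not modeled here), and the function name `randomized_svd` and its `oversample` parameter are illustrative assumptions.

```python
import torch


def randomized_svd(op, rank, oversample=10, seed=0, dtype=torch.float64):
    """Textbook two-pass randomized SVD of an (h, w) operator.

    Only products of the form op @ X and X.T @ op are used, which is what
    makes this family of methods applicable to matrix-free operators.
    """
    h, w = op.shape
    k = min(rank + oversample, h, w)
    gen = torch.Generator().manual_seed(seed)
    omega = torch.randn(w, k, generator=gen, dtype=dtype)  # random test matrix
    q, _ = torch.linalg.qr(op @ omega)  # orthonormal basis for the sketched range
    b = q.T @ op  # (k, w) projection of op onto that basis
    u_small, s, vt = torch.linalg.svd(b, full_matrices=False)
    return (q @ u_small)[:, :rank], s[:rank], vt[:rank]


# Usage on an exactly rank-5 (100, 200) matrix
torch.manual_seed(0)
a = torch.randn(100, 5, dtype=torch.float64) @ torch.randn(5, 200, dtype=torch.float64)
u, s, vt = randomized_svd(a, rank=5)
rel_err = torch.linalg.norm(a - u @ torch.diag(s) @ vt) / torch.linalg.norm(a)
print(f"relative error: {rel_err.item():.1e}")
```

Because the matrix has exact rank 5 and 15 random probes are used, the sketched range contains the whole range of `a`, so the reconstruction error is at the level of floating-point round-off.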

See the documentation for more details, including examples for other decompositions and use cases.

Installation and basic usage

Install via:

pip install skerch

The sketched SVD of a linear operator op of shape (h, w) can then be computed via:

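import torch
from skerch.decompositions import ssvd

# Illustrative placeholders: replace with your own operator and settings
DEVICE, DTYPE = "cpu", torch.float64
op = torch.randn(100, 200, dtype=DTYPE, device=DEVICE)  # e.g. a plain PyTorch matrix
NUM_OUTER, NUM_INNER = 20, 40  # numbers of outer and inner measurements

q, u, s, vt, pt = ssvd(
    op,
    op_device=DEVICE,
    op_dtype=DTYPE,
    outer_dim=NUM_OUTER,
    inner_dim=NUM_INNER,
)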

Here, NUM_OUTER and NUM_INNER specify the number of outer and inner measurements for the sketch. The product q @ u @ diag(s) @ vt @ pt is a PyTorch matrix that approximates op, where q is a thin matrix of shape (h, NUM_OUTER) with orthonormal columns, pt is a thin matrix of shape (NUM_OUTER, w) with orthonormal rows, u and vt are small orthogonal matrices of shape (NUM_OUTER, NUM_OUTER), and s is a vector of NUM_OUTER (approximate) singular values.
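
The factors never need to be multiplied out into a dense (h, w) matrix: applying them right-to-left to a vector costs O(NUM_OUTER · (h + w)) operations. The sketch below uses random stand-ins with the shapes described above (they are not ssvd outputs, and the helper names `orthonormal_cols` and `apply_sketch` are hypothetical, not part of skerch) and checks that the matrix-free product matches the dense one.

```python
import torch

h, w, k = 50, 80, 10  # k plays the role of NUM_OUTER
gen = torch.Generator().manual_seed(0)


def orthonormal_cols(n, m):
    """(n, m) matrix with orthonormal columns, from the QR of a Gaussian matrix."""
    q, _ = torch.linalg.qr(torch.randn(n, m, generator=gen, dtype=torch.float64))
    return q


# Hypothetical stand-ins with the shapes returned by ssvd
q = orthonormal_cols(h, k)     # (h, k): orthonormal columns
u = orthonormal_cols(k, k)     # (k, k): orthogonal
s = torch.rand(k, generator=gen, dtype=torch.float64).sort(descending=True).values  # (k,)
vt = orthonormal_cols(k, k).T  # (k, k): orthogonal
pt = orthonormal_cols(w, k).T  # (k, w): orthonormal rows


def apply_sketch(v):
    """Compute q @ u @ diag(s) @ vt @ pt @ v right-to-left, without forming (h, w)."""
    return q @ (u @ (s * (vt @ (pt @ v))))


dense = q @ u @ torch.diag(s) @ vt @ pt  # (h, w), built here only for checking
v = torch.randn(w, generator=gen, dtype=torch.float64)
print(dense.shape, torch.allclose(apply_sketch(v), dense @ v))
```

Storing the factors takes k(h + w) + 2k² + k numbers (1,510 in this example) instead of h·w (4,000), which is where the memory savings of the decomposition come from.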

The op object must simply satisfy the following criteria:

skerch also provides a convenience PyTorch wrapper for cases where op works with NumPy arrays instead (e.g. SciPy linear operators such as those used in CurvLinOps).
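
Assuming the criteria amount to a `.shape` attribute plus right-matmul (`op @ v`) and left-matmul (`v @ op`) with PyTorch tensors, a NumPy-based operator can be adapted by converting arrays at the boundary, as in the minimal sketch below. `NumpyOpAsTorch` is a hypothetical name and not skerch's own wrapper; the documentation describes the actual class.

```python
import numpy as np
import torch


class NumpyOpAsTorch:
    """Hypothetical wrapper: exposes a NumPy-based operator via PyTorch matmuls."""

    def __init__(self, np_op, dtype=torch.float64, device="cpu"):
        self.np_op = np_op
        self.shape = tuple(np_op.shape)
        self.dtype = dtype
        self.device = device

    def _to_torch(self, arr):
        return torch.as_tensor(np.asarray(arr), dtype=self.dtype, device=self.device)

    def __matmul__(self, x):
        # Right-matmul op @ x, with x a torch tensor of shape (w,) or (w, k)
        return self._to_torch(self.np_op @ x.detach().cpu().numpy())

    def __rmatmul__(self, x):
        # Left-matmul x @ op, with x a torch tensor of shape (h,) or (k, h)
        return self._to_torch(x.detach().cpu().numpy() @ self.np_op)


# A small NumPy matrix stands in for a NumPy/SciPy operator
a = np.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
op = NumpyOpAsTorch(a)
v = torch.ones(3, dtype=torch.float64)
x = torch.ones(2, dtype=torch.float64)
print(op @ v)  # a @ 1 = [3, 12]
print(x @ op)  # 1 @ a = [3, 5, 7]
```

For `x @ op`, PyTorch's `Tensor.__matmul__` does not recognize the wrapper and returns `NotImplemented`, so Python falls back to the wrapper's `__rmatmul__`.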

To get a suggested number of measurements for a given operator shape and measurement budget, run:

python -m skerch prio_hpars --shape=100,200 --budget=12345

The library also implements cheap a-posteriori methods to estimate the error of the obtained sketched approximation:

from skerch.a_posteriori import a_posteriori_error
from skerch.linops import CompositeLinOp, DiagonalLinOp

# (q, u, s, vt, pt) previously computed via ssvd
NUM_A_POSTERIORI = 30  # number of random measurements for the error estimate
sketched_op = CompositeLinOp(
    (
        ("Q", q),
        ("U", u),
        ("S", DiagonalLinOp(s)),
        ("Vt", vt),
        ("Pt", pt),
    )
)

(f1, f2, frob_err) = a_posteriori_error(
    op, sketched_op, NUM_A_POSTERIORI, dtype=DTYPE, device=DEVICE
)[0]
print("Estimated Frob(op):", f1**0.5)
print("Estimated Frob(sketched_op):", f2**0.5)
print("Estimated Frobenius Error:", frob_err**0.5)
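
The idea behind such estimates can be illustrated with Gaussian probes: for g ~ N(0, I), E[||A g||^2] = ||A||_F^2, so averaging over NUM_A_POSTERIORI probes estimates squared Frobenius norms (consistent with the `**0.5` in the prints above). The sketch below implements this generic estimator in plain PyTorch, with a truncated SVD standing in for the sketched approximation; `estimate_frob_sq` is a hypothetical helper, and skerch's `a_posteriori_error` may differ in its details (e.g. the probe distribution).

```python
import torch


def estimate_frob_sq(op, approx, n, seed=0):
    """Estimate ||op||_F^2, ||approx||_F^2 and ||op - approx||_F^2.

    Uses E[||A g||^2] = ||A||_F^2 for g ~ N(0, I). The same n probes are
    shared by all three estimates, and only right-matmuls are required.
    """
    w = op.shape[1]
    gen = torch.Generator().manual_seed(seed)
    g = torch.randn(w, n, generator=gen, dtype=torch.float64)  # one probe per column
    y_op, y_approx = op @ g, approx @ g
    f1 = (y_op ** 2).sum() / n
    f2 = (y_approx ** 2).sum() / n
    err = ((y_op - y_approx) ** 2).sum() / n
    return f1.item(), f2.item(), err.item()


# Example: a rank-50 truncated SVD stands in for the sketched approximation
torch.manual_seed(0)
a = torch.randn(200, 300, dtype=torch.float64)
u, s, vt = torch.linalg.svd(a, full_matrices=False)
a_k = u[:, :50] @ torch.diag(s[:50]) @ vt[:50]

f1, f2, err = estimate_frob_sq(a, a_k, n=30)
true_err = torch.linalg.norm(a - a_k).item()
print("Estimated Frobenius error:", err ** 0.5, "exact:", true_err)
```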

For a given number NUM_A_POSTERIORI of a-posteriori measurements (30 is generally sufficient), the probability that frob_err**0.5 is off by a given amount can be queried as follows:

python -m skerch post_bounds --apost_n=30 --apost_err=0.5
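
To get a feel for how such probabilities shrink as measurements are added, the following Monte Carlo sketch estimates how often the square root of a Gaussian-probe estimate deviates from the true Frobenius norm by more than a relative error `rel_err`, for a rank-1 operator (the case where this estimator has the largest variance). Reading the error as relative is an assumption here; this is an independent illustration, not how `post_bounds` computes its answer, and `deviation_probability` is a hypothetical helper.

```python
import torch


def deviation_probability(n, rel_err, trials=100_000, seed=0):
    """Monte Carlo estimate of P(|sqrt(est) / ||A||_F - 1| > rel_err) for rank-1 A.

    For A = sigma * u v^T and a Gaussian probe g, ||A g||^2 = sigma^2 (v^T g)^2,
    so the n-probe estimate divided by ||A||_F^2 is distributed as chi2(n) / n.
    """
    gen = torch.Generator().manual_seed(seed)
    z = torch.randn(trials, n, generator=gen, dtype=torch.float64)
    ratio = (z ** 2).mean(dim=1) ** 0.5  # sqrt(est) / ||A||_F, one value per trial
    return ((ratio - 1).abs() > rel_err).double().mean().item()


for n in (5, 10, 30):
    print(n, deviation_probability(n, rel_err=0.5))
```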

See the Getting Started, Examples, and API sections of the documentation for more details.

Developers

Contributions are most welcome under this repo's LICENSE. Feel free to open an issue with bug reports, feature requests, etc.

The documentation contains a For Developers section with useful guidelines to interact with this repo.