JaxGaussianProcesses / JaxLinOp

Linear Operators in JAX.
Apache License 2.0

dev: Cache expensive (no-argument) LinearOperator methods that are reused in other calculations. #7

Open · daniel-dodd opened this issue 1 year ago

daniel-dodd commented 1 year ago

Some expensive LinearOperator methods produce results that are reused in later calculations. For example, the .to_root() method of a DenseLinearOperator returns its Cholesky factor, which costs O(n^3) to compute for an n × n matrix. Once computed, however, the factor is cheap to reuse in matrix operations such as solving linear systems or computing log-determinants. Caching it would avoid recomputing the factorisation on every such call and so improve performance.
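To make the saving concrete, here is a minimal sketch that uses plain jax.numpy arrays instead of the JaxLinOp classes (an assumption made to keep it self-contained). The Cholesky factor is computed once and then reused for both a linear solve and the log-determinant.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# A small symmetric positive-definite matrix (illustrative values only).
A = jnp.array([[4.0, 1.0, 0.5],
               [1.0, 3.0, 0.2],
               [0.5, 0.2, 2.0]])

# The expensive O(n^3) step, done once: lower triangular L with A = L @ L.T.
L = jnp.linalg.cholesky(A)

# Each reuse below costs only O(n^2) given L.
b = jnp.array([1.0, 2.0, 3.0])
x = cho_solve((L, True), b)                     # solves A x = b via two triangular solves
log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))   # log|A| = 2 * sum(log(diag(L)))
```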

daniel-dodd commented 1 year ago

I have investigated this. Any caching has to be functional (returning new objects rather than mutating the operator in place); otherwise code will break, e.g., inside lax.scan.
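A minimal sketch of what "functional" means here, assuming a hypothetical DenseOp NamedTuple and hypothetical helpers with_cached_root and solve in place of the real LinearOperator classes. The root is computed once and stored in a new object rather than written into an existing one; because NamedTuples are JAX pytrees, the operator and its cached root can be threaded through the lax.scan carry unchanged.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


class DenseOp(NamedTuple):
    """Hypothetical dense operator; as a NamedTuple it is a JAX pytree."""
    matrix: jnp.ndarray
    root: jnp.ndarray  # lower Cholesky factor, cached next to the matrix


def with_cached_root(matrix: jnp.ndarray) -> DenseOp:
    # Factorise once and return a NEW object; nothing is mutated in place.
    return DenseOp(matrix=matrix, root=jnp.linalg.cholesky(matrix))


def solve(op: DenseOp, b: jnp.ndarray) -> jnp.ndarray:
    # Reuses the cached factor: two triangular solves, no new factorisation.
    return cho_solve((op.root, True), b)


def step(carry, _):
    op, x = carry
    x_new = solve(op, x)
    # The operator (with its cached root) passes through the carry unchanged.
    return (op, x_new), x_new


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
op = with_cached_root(A)
(op_out, x_final), xs = jax.lax.scan(step, (op, jnp.array([1.0, 2.0])), None, length=3)
```

Mutating a Python attribute on the operator inside the scanned function would instead store a traced value that is invalid once tracing ends, which is the kind of breakage the functional approach avoids.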

Root caching (if the dense matrix is not needed in further calculations, or is cheap to compute from the root) can safely be achieved in the current software with:

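from typing import Tuple
from jaxlinop import LinearOperator

def compute_and_cache_root(lo: LinearOperator) -> Tuple[LinearOperator, LinearOperator]:
    root_lo = lo.to_root()  # expensive step (e.g. Cholesky), done once
    cached_lo = lo.from_root(root_lo)  # new operator built from the root; lo is not mutated
    return root_lo, cached_lo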

We might consider adding a cache_root class method, similar to from_root, that builds a new operator from the root but skips constructing the dense matrix.
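One possible shape for such a method, sketched under assumptions: RootedDense and its cache_root, to_dense and solve methods are hypothetical names, not JaxLinOp's API. The operator stores only the lower Cholesky root L (with A = L Lᵀ) and rebuilds the dense matrix on demand, which fits the case described earlier in this thread where the dense matrix is either unused or cheap to recover from the root.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


class RootedDense(NamedTuple):
    """Hypothetical operator storing only a lower Cholesky root L, where A = L @ L.T."""
    root: jnp.ndarray

    @classmethod
    def cache_root(cls, matrix: jnp.ndarray) -> "RootedDense":
        # Factorise once; the dense matrix itself is not kept.
        return cls(root=jnp.linalg.cholesky(matrix))

    def to_dense(self) -> jnp.ndarray:
        # Rebuilt on demand, only if a caller actually needs the dense matrix.
        return self.root @ self.root.T

    def solve(self, b: jnp.ndarray) -> jnp.ndarray:
        # Uses the stored root directly; no refactorisation.
        return cho_solve((self.root, True), b)


A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
op = RootedDense.cache_root(A)
x = op.solve(jnp.array([1.0, 1.0]))
```

Because RootedDense is a NamedTuple, it is also a JAX pytree, so it stays compatible with lax.scan in the functional style discussed in this thread.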