wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

[Bug] slogdet for matrices of dimension larger than 1000 #91

Open neel-maniar opened 1 month ago

🐛 Bug

Taking the determinant of a matrix larger than $1000\times 1000$ fails with an error, apparently because of an issue with beartype type checking.

For example, run the code below with dim=1000 and then with dim=1001. The bug occurs whether or not the line Sigma = cola.PSD(Sigma) is included.

To reproduce

Code snippet to reproduce

import cola
import cola.linalg
import jax.numpy as jnp
import jax.random as jr
from jax import config

config.update("jax_enable_x64", True)

dim = 1001
master_key = jr.key(0)

A = jr.normal(master_key, (dim, dim))
Sigma = A @ A.T  # Ensure Sigma is PSD
Sigma += 0.1 * jnp.eye(dim)  # Ensure Sigma has strictly positive determinant
Sigma = cola.ops.Dense(Sigma)  # Convert Sigma to a cola Dense object
Sigma = cola.PSD(Sigma)  # Tell cola that Sigma is PSD

print("Signed Log Determinant:")
print("-----------------------")
print("(jax.numpy)", jnp.linalg.slogdet(Sigma.to_dense()))
print("(cola)", cola.linalg.slogdet(Sigma))

Stack trace/error message

Signed Log Determinant:
-----------------------
(jax.numpy) SlogdetResult(sign=Array(1., dtype=float64), logabsdet=Array(5935.35245205, dtype=float64))
C:\Users\neelm\miniconda3\envs\gp\lib\site-packages\beartype\_util\hint\pep\utilpeptest.py:311: BeartypeDecorHintPep585DeprecationWarning: PEP 484 type hint typing.Callable deprecated by PEP 585. This hint is scheduled for removal in the first Python version released after October 5th, 2025. To resolve this, import this hint from "beartype.typing" rather than "typing". For further commentary and alternatives, see also:
    https://beartype.readthedocs.io/en/latest/api_roar/#pep-585-deprecations
  warn(

Expected Behavior

With dim=1000, I get expected behaviour:

Signed Log Determinant:
-----------------------
(jax.numpy) SlogdetResult(sign=Array(1., dtype=float64), logabsdet=Array(5926.90396938, dtype=float64))
(cola) (Array(1., dtype=float64), Array(5926.90396938, dtype=float64))
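
As an independent reference for the expected value, the log-determinant of a positive definite matrix can be computed from its Cholesky factor, since $\log\det\Sigma = 2\sum_i \log L_{ii}$ when $\Sigma = LL^\top$. Below is a minimal sketch in plain NumPy, so that it depends on neither JAX nor CoLA. The helper name slogdet_cholesky is hypothetical, and the random matrix is not the one generated by jr.key(0), so the printed values will differ from those reported in this issue.

```python
import numpy as np


def slogdet_cholesky(sigma):
    """Sign and log|det| of a symmetric positive definite matrix.

    Hypothetical helper for cross-checking; not part of CoLA or JAX.
    """
    # np.linalg.cholesky raises LinAlgError if sigma is not positive definite.
    chol = np.linalg.cholesky(sigma)
    # det(sigma) = prod(diag(chol))**2, so log det = 2 * sum(log(diag(chol))).
    logabsdet = 2.0 * np.sum(np.log(np.diag(chol)))
    return 1.0, logabsdet


dim = 1001  # the size at which the report says the problem first appears
rng = np.random.default_rng(0)
a = rng.standard_normal((dim, dim))
sigma = a @ a.T + 0.1 * np.eye(dim)  # PSD plus a shift, as in the report

sign_ref, logdet_ref = slogdet_cholesky(sigma)
sign_np, logdet_np = np.linalg.slogdet(sigma)

print("(cholesky)", sign_ref, logdet_ref)
print("(numpy)   ", sign_np, logdet_np)
```

If the two values agree at dim=1001, a dense computation at this size is well conditioned, which suggests that any failure comes from the library call rather than from the matrix itself.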

System information

CoLA Version: 0.0.5
JAX Version: 0.4.28
Computer OS: Windows 11 Home

Additional context

Could be related to #41 but it seems like a different error.

There also seem to be accuracy issues with cola.solve for matrices larger than $1000\times 1000$.
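
One way to quantify the accuracy of a solve is the relative residual $\lVert \Sigma x - b\rVert / \lVert b\rVert$, which needs only the matrix, the right-hand side, and the returned solution. The sketch below is an assumption-laden illustration: relative_residual is a hypothetical helper, and np.linalg.solve stands in for the solver under test because this report does not show the exact cola.solve call that was used.

```python
import numpy as np


def relative_residual(a, x, b):
    """Return ||a @ x - b|| / ||b|| (hypothetical helper, not a CoLA function)."""
    return np.linalg.norm(a @ x - b) / np.linalg.norm(b)


dim = 1001
rng = np.random.default_rng(0)
m = rng.standard_normal((dim, dim))
sigma = m @ m.T + 0.1 * np.eye(dim)  # same construction as in the report
b = rng.standard_normal(dim)

# Dense reference solve; swap in the output of the solver being investigated.
x_ref = np.linalg.solve(sigma, b)
res_ref = relative_residual(sigma, x_ref, b)

print("relative residual (dense reference):", res_ref)
```

Comparing this number at dim=1000 and dim=1001 for the solver in question would show whether the accuracy changes abruptly at that size.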