wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

[Bug] AttributeError: 'DenseLinearOperator' object has no attribute 'astype' when applying cholesky #65

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

I have just tried to upgrade from commit 74406c9fb2b02869511befea10f02beb75dd3861, to the latest in the main branch, and I am now presented with the following error. I am using the lower_cholesky function provided from GPJax, https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py and K is of type <15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>

    l_zz = lower_cholesky(K)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/plum/function.py", line 438, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/media/adam/shared_drive/PycharmProjects/Process_Shape_Datasets/lower_cholesky.py", line 38, in lower_cholesky
    return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 184, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in _matmat
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in <genexpr>
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 25, in _matmat
    return self.xnp.cast(self.A, dtype) @ self.xnp.cast(X, dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py", line 168, in cast
    return array.astype(dtype)
AttributeError: 'DenseLinearOperator' object has no attribute 'astype'. Did you mean: 'dtype'?

Normally I would provide an MVE, but I am not sure how best to reproduce this in isolation. Let me know how I can best assist if required.

mfinzi commented 1 year ago

Hmm. It looks to me like in this case the .A attribute of a Dense LinearOperator is itself a LinearOperator, whereas it is expected to be an array (though maybe it would not explicitly fail before this commit, instead operating in an unintended way). Can you link the code used to produce the linear operator that enters into this example?

(Later we can add some better input validation for the basic linear operators to help raise these errors earlier)
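A minimal sketch of that kind of validation, in the spirit of the suggestion above (the Dense class here is an illustrative stand-in, not cola's actual implementation):

```python
import numpy as np

class Dense:
    """Illustrative stand-in for cola.ops.Dense (not the real class)."""
    def __init__(self, A):
        # Fail fast if A is not array-like: another LinearOperator
        # (e.g. a double-wrapped Dense) lacks .astype, which would
        # otherwise only surface deep inside a later matmul/cast.
        if not hasattr(A, "astype"):
            raise TypeError(
                f"Dense expects an array, got {type(A).__name__}")
        self.A = A
        self.shape = A.shape
        self.dtype = A.dtype
```

With a check like this, wrapping an operator in Dense would raise a clear TypeError at construction time instead of an AttributeError deep inside to_dense.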

adam-hartshorne commented 1 year ago

After some further research, I have found the issue. The GPJax function kernel.gram(...) now returns a linear operator of type PSD(Dense(M)), which I was then "wrapping" in another Dense (previously the return value of this function was not a cola LinearOperator), and this caused the error, i.e.

k_zz = kernel.gram(Z)
k_zz_cola = cola.ops.Dense(k_zz) <---- this was the root cause of the error
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz_cola + noise * cola.ops.I_like(k_zz_cola))
l_zz = lower_cholesky(K)

Removing the unnecessary extra cola.ops.Dense fixes the error, i.e.

k_zz = kernel.gram(Z)
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz + noise * cola.ops.I_like(k_zz))
l_zz = lower_cholesky(K)

mfinzi commented 1 year ago

Yeah, in general cola.lazify will be safer if you're not sure whether the object is already a linear operator.
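To illustrate why a lazify-style helper is safe to call unconditionally, here is a toy sketch of the idempotent-wrapping pattern (stand-in classes, not cola's real implementation):

```python
import numpy as np

class LinearOperator:
    """Minimal stand-in for cola's operator base class."""
    def __init__(self, A):
        self.A = np.asarray(A)
        self.shape = self.A.shape

def lazify(X):
    # Idempotent wrapping: plain arrays are wrapped exactly once, while
    # objects that are already operators pass through unchanged, so the
    # double-wrapping from this issue cannot occur.
    if isinstance(X, LinearOperator):
        return X
    return LinearOperator(X)
```

Calling lazify on the result of kernel.gram(Z) therefore works whether GPJax returns a raw array or a cola operator.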