adam-hartshorne closed this issue.
Hmm. It looks to me like, in this case, the `.A` attribute of a `Dense` LinearOperator is itself a LinearOperator, whereas it is expected to be an array. (Before this commit it may not have failed explicitly, but it was probably operating in an unintended way.)
Can you link the code used to produce the linear operator that enters into this example?
(Later we can add better input validation to the basic linear operators so that errors like this are raised earlier.)
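One possible shape for that validation is sketched below. The `LinearOperator` and `Dense` classes are hypothetical stand-ins written for illustration, not CoLA's actual classes. The idea is to reject a linear operator at construction time, so the double-wrapping mistake fails immediately instead of surfacing later inside something like `lower_cholesky`.

```python
import numpy as np

class LinearOperator:
    """Hypothetical stand-in for a CoLA-style linear operator base class."""

class Dense(LinearOperator):
    """Hypothetical dense operator that validates its input eagerly."""
    def __init__(self, A):
        # Refuse to nest operators: `.A` must end up as a plain array.
        if isinstance(A, LinearOperator):
            raise TypeError(
                "Dense expects an array but received a LinearOperator; "
                "use the operator directly instead of re-wrapping it"
            )
        A = np.asarray(A)
        # A dense operator only makes sense for a 2-D matrix.
        if A.ndim != 2:
            raise ValueError(f"Dense expects a 2-D array, got shape {A.shape}")
        self.A = A
        self.shape = A.shape

op = Dense(np.eye(3))
try:
    Dense(op)  # the double-wrapping mistake, now caught up front
except TypeError as err:
    print("rejected:", err)
```

Raising `TypeError` at construction keeps the error next to the line that caused it, which is exactly where it was hard to trace in this thread.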
After some further digging, I found the issue. The GPJax function `kernel.gram(...)` now returns a linear operator of type `PSD(Dense(M))`, which I was then wrapping in another `Dense` (previously this function did not return a CoLA linear operator). That double wrapping caused the error:
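```python
import jax.numpy as jnp
import cola
from gpjax.lower_cholesky import lower_cholesky

k_zz = kernel.gram(Z)  # already a CoLA operator: PSD(Dense(M))
k_zz_cola = cola.ops.Dense(k_zz)  # <- root cause of the error: re-wraps an operator
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz_cola + noise * cola.ops.I_like(k_zz_cola))
l_zz = lower_cholesky(K)
```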
Removing the unnecessary extra `cola.ops.Dense` fixes the error:
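```python
import jax.numpy as jnp
import cola
from gpjax.lower_cholesky import lower_cholesky

k_zz = kernel.gram(Z)  # use the returned operator directly
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz + noise * cola.ops.I_like(k_zz))
l_zz = lower_cholesky(K)
```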
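For intuition about what the fixed snippet computes: `K` is the Gram matrix plus `noise` times the identity (which matches the `Sum[Dense, Product[ScalarMul, Identity]]` type reported in the original post), and `lower_cholesky` returns a lower-triangular Cholesky factor of it. The dense NumPy sketch below reproduces that arithmetic under stated assumptions: the RBF kernel, its unit hyperparameters, the 15 inputs and the `log_beta`/`jitter` values are all hypothetical, and it bypasses GPJax and CoLA entirely.

```python
import numpy as np

def rbf_gram(Z, lengthscale=1.0, variance=1.0):
    """Hypothetical RBF Gram matrix, standing in for kernel.gram(Z)."""
    # Squared Euclidean distances between every pair of rows of Z.
    sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq / lengthscale**2)

def noisy_lower_cholesky(Z, log_beta, jitter=1e-6):
    """Dense analogue of PSD(k_zz + noise * I) followed by lower_cholesky."""
    k_zz = rbf_gram(Z)
    noise = np.exp(log_beta) + jitter
    K = k_zz + noise * np.eye(k_zz.shape[0])
    # np.linalg.cholesky returns the lower-triangular L with L @ L.T == K.
    return K, np.linalg.cholesky(K)

Z = np.linspace(0.0, 1.0, 15).reshape(-1, 1)
K, L = noisy_lower_cholesky(Z, log_beta=np.log(0.1))
print(np.allclose(L @ L.T, K))  # True
```

The added diagonal term is what keeps the factorisation stable: the bare Gram matrix of closely spaced inputs is nearly singular, while `noise * I` makes it comfortably positive definite.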
Yeah, in general `cola.lazify` is safer if you're not sure whether the object is already a linear operator.
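A lazify-style helper is safer because it is idempotent: an object that is already a linear operator passes through unchanged, and only raw arrays get wrapped. A minimal sketch of that idea, using hypothetical `LinearOperator`/`Dense` stand-ins and a hypothetical `lazify_sketch` function rather than CoLA's real implementation:

```python
import numpy as np

class LinearOperator:
    """Hypothetical stand-in for a CoLA-style linear operator base class."""

class Dense(LinearOperator):
    """Hypothetical dense operator; `.A` is meant to hold a plain 2-D array."""
    def __init__(self, A):
        self.A = np.asarray(A)

def lazify_sketch(obj):
    """Hypothetical lazify-like helper: wrap arrays once, pass operators through."""
    if isinstance(obj, LinearOperator):
        return obj
    return Dense(obj)

op = lazify_sketch(np.eye(3))
op_again = lazify_sketch(op)  # already an operator, so it is returned unchanged
print(op_again is op, type(op.A).__name__)  # True ndarray

# Blind re-wrapping instead stores the operator itself inside a 0-d object array.
naive = Dense(op)
print(naive.A.shape, naive.A.dtype)  # () object
```

The `naive` case shows the failure mode from this thread in miniature: `.A` no longer holds the expected matrix, and nothing complains until some later routine tries to use it as one.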
I have just tried to upgrade from commit 74406c9fb2b02869511befea10f02beb75dd3861 to the latest commit on the main branch, and I now get the following error. I am using the `lower_cholesky` function provided by GPJax (https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py), and `K` is of type `<15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>`.
Normally I would provide an MVE, but I am not sure how best to test for this. Let me know how I can help.