daniel-dodd opened 10 months ago
Ah, this seems to be about the trace estimation in the iterative-exact method for computing $\log\det A = \operatorname{Tr}(\log A)$. It looks to be coming from the slicing of I in https://github.com/wilson-labs/cola/blob/main/cola/algorithms/diagonal_estimation.py#L12 for the exact (deterministic) version of the trace estimator.
The issue may arise when n (for an n x n matrix A) is not a multiple of the trace estimator batch size, in which case the arrays in different iterations will have different shapes.
The easiest fix would probably be to construct I_chunk (and the other chunks) with explicit zero padding, and to avoid slicing entirely. I will investigate later this week.
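A minimal numpy sketch of the zero-padding idea (the function name `exact_trace_chunked` is illustrative, not cola's actual code): padding the last identity chunk with zero columns keeps every iteration's array shapes fixed, which is what a JIT tracer needs, without changing the trace.

```python
import numpy as np

def exact_trace_chunked(A, bs):
    """Exact trace of A via chunked identity probes.

    The last chunk of identity columns is zero padded to width `bs`,
    so every iteration sees arrays of the same (n, bs) shape instead
    of a ragged final slice. Zero columns contribute nothing to the
    trace, so the result is unchanged.
    """
    n = A.shape[0]
    total = 0.0
    for start in range(0, n, bs):
        I_chunk = np.zeros((n, bs))
        cols = np.arange(start, min(start + bs, n))
        I_chunk[cols, cols - start] = 1.0  # identity columns, padded
        total += np.trace(I_chunk.T @ (A @ I_chunk))
    return total

A = np.arange(9.0).reshape(3, 3)
print(exact_trace_chunked(A, bs=2))  # matches np.trace(A) == 12.0
```

Because the chunk shape is constant across iterations, a `jax.jit`-compiled version of this loop would trace once rather than once per distinct remainder size.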
Also, for GP applications where you only need unbiased estimates of the MLL gradients, you might want to consider SLQ (i.e. 'iterative-stochastic'), which you can also access (it will be selected by auto) if the matrix is large enough and you set a large vtol (the tolerance for the standard deviation of the unbiased estimator), such as vtol=1/5.
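For context, a minimal numpy sketch of the stochastic estimator idea (illustrative only: a real SLQ implementation evaluates each quadratic form with Lanczos quadrature rather than forming the dense matrix logarithm, and this stopping rule is an assumption, not cola's actual code):

```python
import numpy as np

def hutchinson_logdet(A, vtol=0.2, max_probes=1000, seed=0):
    """Unbiased estimate of log det A = Tr(log A) for symmetric PD A.

    Draws Rademacher probes z, averages z^T log(A) z, and stops once
    the standard error of the running mean drops below vtol.
    """
    rng = np.random.default_rng(seed)
    # Dense log(A) via eigendecomposition; SLQ avoids this step.
    w, V = np.linalg.eigh(A)
    L = (V * np.log(w)) @ V.T
    samples = []
    for _ in range(max_probes):
        z = rng.choice([-1.0, 1.0], size=A.shape[0])
        samples.append(z @ L @ z)  # unbiased: E[z^T L z] = Tr(L)
        if len(samples) >= 10:
            stderr = np.std(samples, ddof=1) / np.sqrt(len(samples))
            if stderr < vtol:
                break
    return float(np.mean(samples))
```

A larger vtol means fewer probes and a noisier (but still unbiased) estimate, which is often good enough for stochastic MLL gradient steps; the exact result from `np.linalg.slogdet` is the reference to compare against.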
Oh, actually it's not that; on closer inspection it's the slicing of the Lanczos tridiagonal matrix in apply_unary: https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary.py#L33
🐛 Bug
Issue with log determinant JIT compilation on large matrices (> 1e6). Perhaps an issue with the `iterative` method, which I believe is triggered above 1e6. I worked around this issue by specifying the `method="dense"` kwarg and seem to have no issues there.

To reproduce
https://github.com/JaxGaussianProcesses/GPJax/actions/runs/6099328110/job/16551211143?pr=370
System information

main branch.