wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

[Bug] Logdet #41

Open daniel-dodd opened 10 months ago

daniel-dodd commented 10 months ago

🐛 Bug

Issue with jit-compiling the log determinant on large matrices (> 1e6). Perhaps it is an issue with the iterative method, which I believe is triggered above 1e6.

I worked around this issue by specifying the method="dense" kwarg, which seems to have no issues.
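For comparison, a dense log-determinant involves only static shapes, so it traces under jit without trouble. The sketch below is not CoLA's dense implementation. It assumes sigma = K + diag(d) is symmetric positive definite and uses the Cholesky identity log det(sigma) = 2 * sum(log(diag(L))) with sigma = L L^T. The names dense_logdet, K and d are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dense_logdet(K, d):
    """log det(K + diag(d)), assuming the sum is symmetric positive definite."""
    sigma = K + jnp.diag(d)
    L = jnp.linalg.cholesky(sigma)  # sigma = L @ L.T
    # det(sigma) = prod(diag(L))**2, so log det(sigma) = 2 * sum(log(diag(L))).
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))


# Hypothetical input: a low-rank PSD term plus a positive diagonal.
X = jax.random.normal(jax.random.PRNGKey(0), (50, 5))
K = X @ X.T
d = jnp.full((50,), 0.5)
print(dense_logdet(K, d))
```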

To reproduce

import jax
import cola
# Jit-compile logdet and call it on an input above the 1e6 size threshold (K is n x n, d has length n).
# Here sigma is a SumLinearOperator of a Dense LinOp and a Diagonal; this may be specific to SumLinearOperators.
@jax.jit
def logdet_fn(K, d):
    return cola.logdet(cola.ops.Dense(K) + cola.ops.Diagonal(d))
logdet_fn(K, d)

https://github.com/JaxGaussianProcesses/GPJax/actions/runs/6099328110/job/16551211143?pr=370

--> 189     + cola.logdet(sigma)
    190     + diff.T @ cola.solve(sigma, diff)
    191 )

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:39, in logdet(A, **kwargs)
     17 @export
     18 def logdet(A: LinearOperator, **kwargs):
     19     r""" Computes logdet of a linear operator.
     20 
     21     For large inputs (or with method='iterative'),
   (...)
     37         Array: logdet
     38     """
---> 39     _, ld = slogdet(A,**kwargs)
     40     return ld

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:96, in slogdet(A, **kwargs)
     93 elif 'exact' in method or not stochastic_faster:
     94     # TODO: explicit autograd rule for this case?
     95     logA = cola.linalg.log(A, tol=tol, method='iterative', **kws)
---> 96     trlogA = cola.linalg.trace(logA,method='exact',**kws)
     97 else:
     98     raise ValueError(f"Unknown method {method} or CoLA didn't fit any selection criteria")

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/diag_trace.py:137, in trace(A, **kwargs)
    117 r""" Compute the trace of a linear operator tr(A).
    118 
    119 Uses either :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation (Hutchinson estimator)
   (...)
    134 Returns:
    135     Array: trace"""
    136 assert A.shape[0] == A.shape[1], "Can't trace non square matrix"
--> 723   return getattr(self.aval,f"_{name}")(self,*args)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4153, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4153 return _gather(arr,treedef,static_idx,dynamic_idx,indices_are_sorted,
   4154 unique_indices,mode,fill_value)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4162, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   4159 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4160             unique_indices, mode, fill_value):
   4161   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4162   indexer = _index_to_gather(shape(arr),idx)  # shared with _scatter_update
   4163   y = arr
   4165   if fill_value is not None:

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4414, in _index_to_gather(x_shape, idx, normalize_indices)
   4405 if not all(_is_slice_element_none_or_constant(elt)
   4406            for elt in (start, stop, step)):
   4407   msg = ("Array slice indices must have static start/stop/step to be used "
   4408          "with NumPy indexing syntax. "
   4409          f"Found slice({start}, {stop}, {step}). "
   (...)
   4412          "dynamic_update_slice (JAX does not support dynamically sized "
   4413          "arrays within JIT compiled functions).")
-> 4414   raise IndexError(msg)
   4415 if not core.is_constant_dim(x_shape[x_axis]):
   4416   msg = ("Cannot use NumPy slice indexing on an array dimension whose "
   4417          f"size is not statically known ({x_shape[x_axis]}). "
   4418          "Try using lax.dynamic_slice/dynamic_update_slice")

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Error: Process completed with exit code 1.
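The final IndexError can be reproduced without CoLA: under jax.jit, a NumPy-style slice whose bound is a traced value is rejected, whereas lax.dynamic_slice accepts a traced start index as long as the slice length is static. A minimal sketch (take_prefix and take_window are hypothetical names):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)


@jax.jit
def take_prefix(x, k):
    # k is traced under jit, so the slice length is not static.
    return x[:k]


try:
    take_prefix(x, 4)
    prefix_failed = False
except IndexError:
    # Same error as in the traceback above: slice(None, Traced<...>, None).
    prefix_failed = True


@jax.jit
def take_window(x, start):
    # A static length (3) with a dynamic start is allowed under jit.
    return jax.lax.dynamic_slice_in_dim(x, start, 3)


print(prefix_failed)      # True
print(take_window(x, 4))  # [4. 5. 6.]
```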


mfinzi commented 10 months ago

Ah, this seems to be about the trace estimation in the iterative-exact method for computing $\log \det A = \mathrm{Tr}(\log A)$. It looks to be coming from the slicing of I in https://github.com/wilson-labs/cola/blob/main/cola/algorithms/diagonal_estimation.py#L12 for the exact (deterministic) version of the trace estimator.
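The identity holds for symmetric positive definite A because the eigenvalues of log A are the logs of the eigenvalues of A, so both sides equal the sum of log-eigenvalues. A quick numerical check on a small hypothetical matrix; here log A is formed with eigh, whereas the traceback shows CoLA computing it with method='iterative':

```python
import jax
import jax.numpy as jnp


def trace_log_spd(A):
    """tr(log A) for a symmetric positive definite A via an eigendecomposition."""
    w, V = jnp.linalg.eigh(A)
    log_A = (V * jnp.log(w)) @ V.T  # V diag(log w) V^T
    return jnp.trace(log_A)


X = jax.random.normal(jax.random.PRNGKey(1), (6, 6))
A = X @ X.T + 6.0 * jnp.eye(6)  # SPD by construction

sign, logdet = jnp.linalg.slogdet(A)
print(logdet, trace_log_spd(A))  # agree up to float32 rounding
```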

The issue may arise when n (for an n x n matrix A) is not a multiple of the trace estimator's batch size: the last chunk is then narrower, so the arrays have different shapes in different iterations.
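The shape mismatch hypothesized here is simple arithmetic. If the exact trace walks over column chunks of the identity of width bs, every chunk has the same shape only when bs divides n. A minimal illustration, where chunk_sizes and bs are hypothetical names rather than CoLA's:

```python
def chunk_sizes(n, bs):
    """Widths of the identity chunks I[:, start:start + bs] visited by a batched exact trace."""
    return [min(bs, n - start) for start in range(0, n, bs)]


print(chunk_sizes(12, 4))  # [4, 4, 4]: bs divides n, so every chunk has the same shape
print(chunk_sizes(10, 4))  # [4, 4, 2]: the last chunk is narrower
```

Inside a jit-compiled loop, a chunk whose width changes between iterations cannot be given a single static shape.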

The easiest fix would probably be to construct I_chunk (and the other chunks) explicitly zero-padded, without using any slicing. I will investigate later this week.
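A minimal sketch of that zero-padding idea, assuming the operator is available through a hypothetical matmat function that returns A @ E for a block E. It illustrates the suggestion only and is not CoLA's eventual fix: each chunk is built by comparing row and column indices, so column indices past n simply produce all-zero columns and every iteration has shape (n, bs).

```python
import jax
import jax.numpy as jnp


def exact_trace_padded(matmat, n, bs, dtype=jnp.float32):
    """Exact tr(A) = sum_j e_j^T A e_j, accumulated over fixed-width identity chunks.

    matmat(E) is assumed to return A @ E for an (n, bs) block E. Columns whose
    index would be >= n are left all-zero, so every chunk has shape (n, bs) and
    no slicing of the identity is needed.
    """
    num_chunks = -(-n // bs)  # ceil(n / bs), a static Python int
    rows = jnp.arange(n)

    def body(c, acc):
        cols = c * bs + jnp.arange(bs)  # runs past n in a partial last chunk
        E = (rows[:, None] == cols[None, :]).astype(dtype)  # zero-padded chunk of I
        # Summing E * (A @ E) over rows gives e_j^T A e_j per column (0 for padding).
        return acc + jnp.sum(E * matmat(E))

    return jax.lax.fori_loop(0, num_chunks, body, jnp.zeros((), dtype))


A = jax.random.normal(jax.random.PRNGKey(2), (10, 10))
# n = 10 is not a multiple of bs = 4, yet the loop traces with static shapes.
tr = jax.jit(lambda A: exact_trace_padded(lambda E: A @ E, 10, 4))(A)
print(tr, jnp.trace(A))
```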

Also, for GP applications in which you only need unbiased estimates of the MLL gradients, you might want to consider SLQ (i.e. 'iterative-stochastic'). You can request it directly, and method='auto' will also select it if the matrix is large enough and you set a large vtol (the tolerance for the standard deviation of the unbiased estimator), such as vtol=1/5.
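To show what "unbiased estimator" and its standard deviation mean here, a sketch of the Hutchinson estimator for tr(log A) with Rademacher probes. This is not CoLA's implementation: log A is formed densely with eigh only for illustration, whereas SLQ approximates each z^T log(A) z with Lanczos quadrature. The function name is hypothetical, and treating the returned standard error as the quantity compared against vtol is an assumption.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace_log(A, key, num_probes):
    """Unbiased Hutchinson estimate of tr(log A) = log det A for symmetric positive definite A.

    Returns (estimate, standard_error). log A is formed densely here only for
    illustration; SLQ instead approximates each z^T log(A) z with Lanczos quadrature.
    """
    n = A.shape[0]
    w, V = jnp.linalg.eigh(A)
    log_A = (V * jnp.log(w)) @ V.T
    Z = jax.random.rademacher(key, (n, num_probes), dtype=jnp.float32)  # +/-1 probe vectors
    quad = jnp.sum(Z * (log_A @ Z), axis=0)  # z^T log(A) z for each probe
    estimate = jnp.mean(quad)
    std_err = jnp.std(quad, ddof=1) / jnp.sqrt(num_probes)
    return estimate, std_err


X = jax.random.normal(jax.random.PRNGKey(3), (20, 20))
A = X @ X.T / 20 + jnp.eye(20)
est, se = hutchinson_trace_log(A, jax.random.PRNGKey(4), 2000)
print(est, se, jnp.linalg.slogdet(A)[1])
```

Fewer probes give a larger standard error, which is the trade-off a loose tolerance such as vtol=1/5 accepts in exchange for speed.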

mfinzi commented 10 months ago

Oh, actually it's not that. On closer inspection, it's the slicing of the Lanczos tridiagonal matrix in apply_unary: https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary.py#L33
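The traceback shows the slice bound is a traced int32, consistent with slicing the tridiagonal matrix to a data-dependent number of Lanczos steps. One jit-friendly pattern (a sketch, not necessarily how CoLA resolved it) keeps the tridiagonal matrix at its fixed maximum size m and masks the unused part instead of slicing. The names are hypothetical: alphas and betas stand for the Lanczos diagonal and off-diagonal coefficients, and k for the number of completed iterations.

```python
import jax
import jax.numpy as jnp


def lanczos_quadrature_masked(alphas, betas, k, f):
    """e1^T f(T_k) e1, where T_k is the leading k x k block of a fixed-size tridiagonal T.

    alphas has shape (m,) and betas shape (m - 1,); k may be a traced value.
    Instead of slicing T[:k, :k], rows and columns >= k are decoupled: their
    off-diagonals are zeroed and their diagonal is set to 1.0, so
    T = blockdiag(T_k, I) and e1^T f(T) e1 = e1^T f(T_k) e1.
    """
    m = alphas.shape[0]
    idx = jnp.arange(m)
    diag = jnp.where(idx < k, alphas, 1.0)
    off = jnp.where(idx[:-1] < k - 1, betas, 0.0)  # betas[i] couples rows i and i + 1
    T = jnp.diag(diag) + jnp.diag(off, 1) + jnp.diag(off, -1)
    theta, U = jnp.linalg.eigh(T)
    # Gauss quadrature form: sum_j U[0, j]^2 f(theta_j) = e1^T f(T) e1.
    return jnp.sum(U[0, :] ** 2 * f(theta))


alphas = jnp.array([4.0, 3.0, 5.0, 4.5, 9.9, 9.9])  # entries past k are ignored
betas = jnp.array([0.5, 0.7, 0.3, 0.8, 0.2])
quad_log = jax.jit(lambda a, b, k: lanczos_quadrature_masked(a, b, k, jnp.log))
print(quad_log(alphas, betas, 4))  # k is traced, yet every shape stays static
```

Because f(T) is block diagonal, the padded block never touches the first coordinate, so the result matches slicing T[:k, :k] eagerly, while one compiled function serves every k.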