wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

[Bug] Logdet #41

Open daniel-dodd opened 10 months ago

daniel-dodd commented 10 months ago

🐛 Bug

Issue with jit-compiling the log determinant of large matrices (above 1e6). Perhaps it's an issue with the iterative method, which I believe is triggered above 1e6.

I worked around this issue by passing the method="dense" kwarg, and there seem to be no issues with that.
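A likely reason the dense route is unaffected: a Cholesky factorization works on arrays whose shapes are fixed by n alone, so nothing depends on a traced iteration count. A minimal plain-JAX sketch, assuming sigma is symmetric positive definite; dense_logdet, K, and noise are hypothetical names, and this is not CoLA's own dense code path.

```python
import jax
import jax.numpy as jnp

@jax.jit
def dense_logdet(K, noise):
    # Hypothetical sigma = K + diag(noise), materialized as a dense n x n array.
    sigma = K + jnp.diag(noise)
    # Cholesky sigma = L L^T (requires sigma symmetric positive definite);
    # then log det(sigma) = 2 * sum(log(diag(L))).
    L = jnp.linalg.cholesky(sigma)
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))

X = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
K = X @ X.T                # PSD "dense" part
noise = 0.5 * jnp.ones(6)  # positive diagonal makes sigma positive definite
ld = dense_logdet(K, noise)
```

Every intermediate here has a static shape, which is why jit compiles it without complaint, at the cost of O(n^3) work.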

To reproduce

import jax
import cola

# Jit-compile logdet and call it on an input larger than the 1e6 size threshold.
# Here sigma is a SumLinearOperator of a Dense LinOp and a Diagonal array, e.g.
# cola.ops.Dense(K) + cola.ops.Diagonal(noise) (K and noise are hypothetical).
# This may be an issue on SumLinearOperators.
jax.jit(lambda sigma: cola.logdet(sigma))(input_matrix_here)


--> 189     + cola.logdet(sigma)
    190     + diff.T @ cola.solve(sigma, diff)
    191 )

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:39, in logdet(A, **kwargs)
     17 @export
     18 def logdet(A: LinearOperator, **kwargs):
     19     r""" Computes logdet of a linear operator. 
     21     For large inputs (or with method='iterative'),
     37         Array: logdet
     38     """
---> 39     _, ld = slogdet(A,**kwargs)
     40     return ld

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/logdet.py:96, in slogdet(A, **kwargs)
     93 elif 'exact' in method or not stochastic_faster:
     94     # TODO: explicit autograd rule for this case?
     95     logA = cola.linalg.log(A, tol=tol, method='iterative', **kws)
---> 96     trlogA = cola.linalg.trace(logA,method='exact',**kws)
     97 else:
     98     raise ValueError(f"Unknown method {method} or CoLA didn't fit any selection criteria")

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/plum/function.py:438, in Function.__call__(self, *args, **kw_args)
    436     method, return_type, loginfo = self.resolve_method(args, types)
    437 logging.info("%s",loginfo)
--> 438 return _convert(method(*args,**kw_args), return_type)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/cola/linalg/diag_trace.py:137, in trace(A, **kwargs)
    117 r""" Compute the trace of a linear operator tr(A).
    119 Uses either :math:`O(\tfrac{1}{\delta^2})` time stochastic estimation (Hutchinson estimator)
    134 Returns:
    135     Array: trace"""
    136 assert A.shape[0] == A.shape[1], "Can't trace non square matrix"
--> 723   return getattr(self.aval,f"_{name}")(self,*args)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4153, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4150       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
   4152 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4153 return _gather(arr,treedef,static_idx,dynamic_idx,indices_are_sorted,
   4154 unique_indices,mode,fill_value)

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4162, in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
   4159 def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4160             unique_indices, mode, fill_value):
   4161   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4162   indexer = _index_to_gather(shape(arr),idx)  # shared with _scatter_update
   4163   y = arr
   4165   if fill_value is not None:

File /usr/share/miniconda/envs/test/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:4414, in _index_to_gather(x_shape, idx, normalize_indices)
   4405 if not all(_is_slice_element_none_or_constant(elt)
   4406            for elt in (start, stop, step)):
   4407   msg = ("Array slice indices must have static start/stop/step to be used "
   4408          "with NumPy indexing syntax. "
   4409          f"Found slice({start}, {stop}, {step}). "
   4412          "dynamic_update_slice (JAX does not support dynamically sized "
   4413          "arrays within JIT compiled functions).")
-> 4414   raise IndexError(msg)
   4415 if not core.is_constant_dim(x_shape[x_axis]):
   4416   msg = ("Cannot use NumPy slice indexing on an array dimension whose "
   4417          f"size is not statically known ({x_shape[x_axis]}). "
   4418          "Try using lax.dynamic_slice/dynamic_update_slice")

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

Error: Process completed with exit code 1.
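The IndexError at the bottom of the traceback is JAX's general restriction on slicing with a traced bound. A minimal plain-JAX sketch (prefix_sum_sliced and prefix_sum_masked are hypothetical names, unrelated to CoLA) reproduces it and shows the fixed-shape masking alternative:

```python
import jax
import jax.numpy as jnp

def prefix_sum_sliced(x, k):
    # x[:k] with a traced k would need an output whose shape depends on a
    # runtime value, which jit cannot express.
    return jnp.sum(x[:k])

def prefix_sum_masked(x, k):
    # Same result with a fixed shape: keep all of x, zero entries at index >= k.
    mask = jnp.arange(x.shape[0]) < k
    return jnp.sum(jnp.where(mask, x, 0.0))

x = jnp.arange(10.0)
try:
    jax.jit(prefix_sum_sliced)(x, 4)
    sliced_failed = False
except (IndexError, TypeError):
    # The traceback above shows IndexError; TypeError is caught as well in case
    # a different JAX version reports it as a concretization error.
    sliced_failed = True
masked = jax.jit(prefix_sum_masked)(x, 4)  # 0 + 1 + 2 + 3 = 6.0
```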


mfinzi commented 10 months ago

Ah, this seems to be about the trace estimation in the iterative-exact method for computing $\log \mathrm{det} A = \mathrm{Tr}(\log A)$. Looks to be coming from the slicing of I in https://github.com/wilson-labs/cola/blob/main/cola/algorithms/diagonal_estimation.py#L12 for the exact (deterministic) version of the trace estimator.
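For a symmetric positive-definite A, $\log\det A = \mathrm{tr}(\log A)$ because both sides equal the sum of the logs of the eigenvalues. A plain-JAX sketch of the exact (deterministic) trace computed in chunks of identity columns, with a dense eigendecomposition standing in for CoLA's Lanczos-based log; logm_spd, exact_trace_chunked, and bs are hypothetical names, not CoLA's diagonal_estimation code.

```python
import jax
import jax.numpy as jnp

def logm_spd(A):
    # Matrix logarithm of a symmetric positive-definite A via its
    # eigendecomposition: log A = Q diag(log w) Q^T.
    w, Q = jnp.linalg.eigh(A)
    return (Q * jnp.log(w)) @ Q.T

def exact_trace_chunked(matvec, n, bs):
    # Deterministic trace: tr(M) = sum_i e_i^T M e_i, applying M to bs
    # identity columns at a time. When bs does not divide n, the last chunk is
    # narrower; a Python loop handles that, but a traced loop body cannot.
    I = jnp.eye(n)
    total = 0.0
    for start in range(0, n, bs):
        I_chunk = I[:, start:start + bs]
        total = total + jnp.sum(I_chunk * matvec(I_chunk))
    return total

n = 7
X = jax.random.normal(jax.random.PRNGKey(1), (n, n))
A = X @ X.T + n * jnp.eye(n)  # symmetric positive definite
logA = logm_spd(A)
ld_trace = exact_trace_chunked(lambda V: logA @ V, n, bs=3)  # chunks of 3, 3, 1
```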

The issue may arise when n (for an n x n matrix A) is not a multiple of the trace estimator's batch size: the chunk arrays then have different shapes in different iterations.

The easiest fix would probably be to construct I_chunk (and the other chunks) with explicit zero padding, so that no slicing is needed. I will investigate later this week.
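A sketch of that zero-padding idea in plain JAX, assuming the operator is available as a batched matvec; exact_trace_padded, matvec, and bs are hypothetical names and this is not CoLA's code. The padded identity is reshaped into equally sized chunks, so the traced loop body always sees an (n, bs) array.

```python
import jax
import jax.numpy as jnp

def exact_trace_padded(matvec, n, bs):
    # Zero-pad the identity to a multiple of bs columns, then stack it into
    # equally sized chunks: I_chunks[c] holds columns c*bs .. c*bs + bs - 1.
    num_chunks = -(-n // bs)  # ceil(n / bs)
    I_padded = jnp.eye(n, num_chunks * bs)  # extra columns are all zero
    I_chunks = I_padded.reshape(n, num_chunks, bs).transpose(1, 0, 2)
    # Padding columns are zero, so they add nothing to the sum.
    per_chunk = jax.lax.map(lambda V: jnp.sum(V * matvec(V)), I_chunks)
    return jnp.sum(per_chunk)

n = 7
X = jax.random.normal(jax.random.PRNGKey(2), (n, n))
A = X @ X.T
tr = jax.jit(lambda M: exact_trace_padded(lambda V: M @ V, n, 3))(A)
```

The cost is up to bs - 1 wasted matvec columns, in exchange for a loop body with a single static shape.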

Also, for GP applications where you only need unbiased estimates of the marginal log likelihood (MLL) gradients, you might want to consider SLQ (stochastic Lanczos quadrature, i.e., 'iterative-stochastic'). You can request it directly, and auto will select it if the matrix is large enough and you set a large vtol (the tolerance on the standard deviation of the unbiased estimator), such as vtol=1/5.
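For intuition about what SLQ computes, a compact plain-JAX sketch (not CoLA's implementation): Hutchinson's estimator with Rademacher probes, where each quadratic form is evaluated by Gauss quadrature on a fixed number m of Lanczos steps. lanczos_tridiag, slq_logdet, num_probes, and m are hypothetical; there is no reorthogonalization and no vtol-based stopping, which a real implementation would need.

```python
import jax
import jax.numpy as jnp

def lanczos_tridiag(matvec, v0, m):
    # m Lanczos steps from the unit vector v0 (no reorthogonalization);
    # returns the m x m tridiagonal matrix T.
    def step(carry, _):
        v_prev, v, beta_prev = carry
        w = matvec(v) - beta_prev * v_prev
        alpha = jnp.dot(w, v)
        w = w - alpha * v
        beta = jnp.linalg.norm(w)
        return (v, w / beta, beta), (alpha, beta)
    init = (jnp.zeros_like(v0), v0, jnp.zeros((), v0.dtype))
    _, (alphas, betas) = jax.lax.scan(step, init, None, length=m)
    off = betas[:-1]
    return jnp.diag(alphas) + jnp.diag(off, 1) + jnp.diag(off, -1)

def slq_logdet(matvec, n, key, num_probes=64, m=20):
    # tr(log A) ~ mean over Rademacher z of z^T log(A) z, with each term
    # approximated as ||z||^2 * sum_k U[0, k]^2 * log(theta_k), T = U diag(theta) U^T.
    def one_probe(probe_key):
        z = jax.random.rademacher(probe_key, (n,)).astype(jnp.float32)
        T = lanczos_tridiag(matvec, z / jnp.sqrt(n), m)  # ||z||^2 = n
        theta, U = jnp.linalg.eigh(T)
        return n * jnp.sum(U[0] ** 2 * jnp.log(theta))
    keys = jax.random.split(key, num_probes)
    return jnp.mean(jax.vmap(one_probe)(keys))

n = 50
Q, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(3), (n, n)))
eigs = jnp.linspace(1.0, 10.0, n)
A = (Q * eigs) @ Q.T  # SPD with a known spectrum
estimate = jax.jit(
    lambda M, key: slq_logdet(lambda v: M @ v, n, key, num_probes=200)
)(A, jax.random.PRNGKey(4))
exact = jnp.sum(jnp.log(eigs))
```

Because m is fixed, every array has a static shape and the whole estimator jit-compiles; the result is unbiased for tr(log A) only up to the Lanczos quadrature error.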

mfinzi commented 10 months ago

Oh, actually it's not that. On closer inspection, it's the slicing of the Lanczos tridiagonal matrix in apply_unary: https://github.com/wilson-labs/cola/blob/main/cola/linalg/unary.py#L33
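One jit-friendly way around slicing T[:k, :k] (a sketch only, not necessarily the fix applied in CoLA): keep the Lanczos coefficients in fixed-size buffers of length m and, instead of slicing to the k valid steps, mask the rest so the unused part of T is decoupled from the first row. masked_tridiag, quad_log, fill, and this buffer layout are hypothetical.

```python
import jax
import jax.numpy as jnp

def masked_tridiag(alphas, betas, k, fill=1.0):
    # alphas: (m,), betas: (m-1,) fixed-size Lanczos buffers, of which only
    # the first k steps are valid; k may be a traced value under jit.
    idx = jnp.arange(alphas.shape[0])
    a = jnp.where(idx < k, alphas, fill)         # positive fill keeps log finite
    b = jnp.where(idx[:-1] < k - 1, betas, 0.0)  # decouples the unused block
    return jnp.diag(a) + jnp.diag(b, 1) + jnp.diag(b, -1)

def quad_log(alphas, betas, k):
    # e1^T log(T) e1 via the eigendecomposition of the full m x m matrix. T is
    # block diagonal, so this equals the value for the leading k x k block.
    theta, U = jnp.linalg.eigh(masked_tridiag(alphas, betas, k))
    return jnp.sum(U[0] ** 2 * jnp.log(theta))

alphas = jnp.array([2.0, 2.5, 3.0, 2.2, 2.8])
betas = jnp.array([0.3, 0.4, 0.2, 0.5])
masked = jax.jit(quad_log)(alphas, betas, 3)  # k = 3 is traced, no slicing

# Reference computed eagerly with static slicing of the leading 3 x 3 block.
T3 = jnp.diag(alphas[:3]) + jnp.diag(betas[:2], 1) + jnp.diag(betas[:2], -1)
w, V = jnp.linalg.eigh(T3)
reference = jnp.sum(V[0] ** 2 * jnp.log(w))
```

Filling the unused diagonal with a positive value rather than 0 matters: the padding eigenvectors get zero weight in the first row, but 0 * log(0) would still produce NaN.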