jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.scipy.linalg.eigh_tridiagonal() doesn't implement calculation of eigenvectors #14019

Open crowsonkb opened 1 year ago

crowsonkb commented 1 year ago

I need this for finding the eigenvectors of the Hessian after tridiagonalizing it with Lanczos iteration. Right now the function looks like:

def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False,
                     select: str = 'a', select_range: Optional[Tuple[float, float]] = None,
                     tol: Optional[float] = None) -> Array:
  if not eigvals_only:
    raise NotImplementedError("Calculation of eigenvectors is not implemented")

The documentation doesn't mention that eigenvector calculation is unimplemented, either; the function just raises a NotImplementedError when you try to use it. If I shell out to SciPy on CPU, I then can't JIT the function. Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?

Thank you, Katherine Crowson

mattjj commented 1 year ago

Thanks for raising this!

Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?

I don't know of any plans, but we don't really plan things like this; instead we wait for users to ask for them!

If I shell out to scipy on cpu I then can't JIT the function.

As a temporary suboptimal workaround, how about using jax.pure_callback? Then you should be able to call into scipy, and jitting will still work. If you want autodiff, you could put a custom_jvp around it, like in this example.
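Something like this sketch might work (untested; the result shapes/dtypes are assumptions based on what scipy.linalg.eigh_tridiagonal returns):

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

def eigh_tridiagonal_cb(d, e):
  # pure_callback needs the result shapes/dtypes declared up front.
  n = d.shape[0]
  shapes = (jax.ShapeDtypeStruct((n,), d.dtype),    # eigenvalues w
            jax.ShapeDtypeStruct((n, n), d.dtype))  # eigenvectors v (as columns)
  def cb(d, e):
    w, v = scipy.linalg.eigh_tridiagonal(np.asarray(d), np.asarray(e))
    return w.astype(d.dtype), v.astype(d.dtype)
  return jax.pure_callback(cb, shapes, d, e)

w, v = jax.jit(eigh_tridiagonal_cb)(jnp.ones(4), jnp.zeros(3))  # works under jit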

@hawkinsp do you happen to know/remember how hard it is to add eigenvector calculations (context: #6622)? Maybe we should ask Rasmus...

mattjj commented 1 year ago

Looks like Rasmus implemented it here for TensorFlow, with great comments, so maybe we can port it?

HHalva commented 1 year ago

This would be great to have!

hawkinsp commented 1 year ago

The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.
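(For context: eigenvectors of a symmetric tridiagonal matrix are typically recovered by inverse iteration, i.e. for each computed eigenvalue w you repeatedly solve (T - w I) x = x and normalize; doing this for all n eigenvalues at once is where the batched solve comes in. A rough unbatched sketch of the idea, not Rasmus's implementation, assuming jax.lax.linalg.tridiagonal_solve is supported on your backend:)

import jax
import jax.numpy as jnp

def eigvec_inverse_iteration(d, e, w, iters=5, shift_eps=1e-6):
  # T = tridiag(e, d, e); repeatedly solve (T - w*I) x = x and normalize.
  n = d.shape[0]
  zero = jnp.zeros((1,), d.dtype)
  dl = jnp.concatenate([zero, e])  # sub-diagonal; dl[0] must be 0
  du = jnp.concatenate([e, zero])  # super-diagonal; du[-1] must be 0
  x = jnp.ones((n, 1), d.dtype)
  for _ in range(iters):
    x = jax.lax.linalg.tridiagonal_solve(dl, d - w + shift_eps, du, x)
    x = x / jnp.linalg.norm(x)
  return x[:, 0]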

HHalva commented 1 year ago

That sounds tricky. It's a shame, since it makes it difficult to implement the popular GP inference engines that use the eigendecomposition of tridiagonal matrices to compute log-determinants cheaply.

AlexanderMath commented 1 year ago

a way to compute log-determinants cheaply.

Perhaps I'm misunderstanding. Isn't log det(M) = sum_i log(lambda_i), so the eigenvalues are sufficient?

HHalva commented 1 year ago

a way to compute log-determinants cheaply.

Perhaps I'm misunderstanding. Isn't log det(M) = sum_i log(lambda_i), so the eigenvalues are sufficient?

eigh is still expensive, no? It doesn't exploit the tridiagonal structure for cheaper compute.

AlexanderMath commented 1 year ago

Sorry, I meant to say eigh_tridiagonal(d, e, eigvals_only=True) instead of eigh; that is, compute only the eigenvalues, exploiting the tridiagonal structure. For a tridiagonal matrix with main diagonal d and off-diagonal e, you can compute the log-determinant as

import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal
def log_det_trid(d, e):  # d: main diagonal, e: off-diagonal
  return jnp.sum(jnp.log(eigh_tridiagonal(d, e, eigvals_only=True)))

This is true because det(M) = prod_i lambda_i, so log det(M) = sum_i log(lambda_i), where the lambda_i are the eigenvalues returned by eigh_tridiagonal(d, e, eigvals_only=True) (assuming they are all positive, i.e. M is positive definite).
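A quick sanity check on a small positive-definite example (values chosen just for illustration):

import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

d = jnp.array([2.0, 2.0, 2.0])  # main diagonal
e = jnp.array([-1.0, -1.0])     # off-diagonal
T = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)
print(jnp.linalg.slogdet(T)[1])                                     # log det via dense LU
print(jnp.sum(jnp.log(eigh_tridiagonal(d, e, eigvals_only=True))))  # same value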

AlexanderMath commented 1 year ago

Even jax.grad(log(det(M))) doesn't require eigenvectors, only the inverse: d/dM log det(M) = inv(M).T (see section 2.1.4, eq. 57, in The Matrix Cookbook).
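For example, checking eq. 57 numerically (a small sketch, with an arbitrary positive-definite matrix):

import jax
import jax.numpy as jnp

M = jnp.array([[2.0, -1.0, 0.0],
               [-1.0, 2.0, -1.0],
               [0.0, -1.0, 2.0]])
g = jax.grad(lambda M: jnp.linalg.slogdet(M)[1])(M)
print(jnp.allclose(g, jnp.linalg.inv(M).T))  # True: d/dM log det(M) = inv(M).T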