crowsonkb opened this issue 1 year ago
Thanks for raising this!
Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?
I don't know of any plans, but we don't really plan things like this; instead we wait for users to ask for them!
If I shell out to SciPy on CPU, I then can't JIT the function.
As a temporary, suboptimal workaround, how about using jax.pure_callback? Then you should be able to call into SciPy, and jitting will still work. If you want autodiff, you could put a custom_jvp around it, like in this example.
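A minimal sketch of that workaround, assuming SciPy is installed and the matrix is given by its diagonal d and off-diagonal e. The names eigh_tridiagonal_cb and tridiag_eigvals are hypothetical. The custom JVP only differentiates the eigenvalues, using first-order perturbation theory, which assumes distinct eigenvalues (true when every off-diagonal entry is nonzero).

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def _host_eigh_tridiagonal(d, e):
    # Runs on the host: pure_callback passes NumPy arrays in.
    w, v = scipy.linalg.eigh_tridiagonal(
        np.asarray(d, dtype=np.float64), np.asarray(e, dtype=np.float64))
    # Cast back so the outputs match the declared ShapeDtypeStructs.
    return w.astype(d.dtype), v.astype(d.dtype)


def eigh_tridiagonal_cb(d, e):
    """Eigenvalues (ascending) and eigenvectors (columns) of the matrix (d, e)."""
    n = d.shape[0]
    out_types = (jax.ShapeDtypeStruct((n,), d.dtype),
                 jax.ShapeDtypeStruct((n, n), d.dtype))
    return jax.pure_callback(_host_eigh_tridiagonal, out_types, d, e)


@jax.custom_jvp
def tridiag_eigvals(d, e):
    w, _ = eigh_tridiagonal_cb(d, e)
    return w


@tridiag_eigvals.defjvp
def _tridiag_eigvals_jvp(primals, tangents):
    d, e = primals
    dd, de = tangents
    w, v = eigh_tridiagonal_cb(d, e)
    # First-order perturbation of simple eigenvalues: dw_i = v_i^T dT v_i,
    # where dT has dd on the diagonal and de on both off-diagonals.
    dw = (jnp.sum(dd[:, None] * v**2, axis=0)
          + 2.0 * jnp.sum(de[:, None] * v[:-1] * v[1:], axis=0))
    return w, dw
```

Under jax.jit this compiles, but the SciPy call still executes on the host, so on an accelerator every call pays a device-to-host round trip.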
@hawkinsp do you happen to know/remember how hard it is to add eigenvector calculations (context: #6622)? Maybe we should ask Rasmus...
Looks like Rasmus implemented it here for TensorFlow, with great comments, so maybe we can port it?
This would be great to have!
The missing piece for porting Rasmus's implementation is a batched tridiagonal solve, I believe.
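To illustrate what a batched tridiagonal solve involves, here is a minimal sketch using the Thomas algorithm written with lax.scan and batched with jax.vmap. It is not Rasmus's implementation, and the names thomas_solve and batched_thomas_solve are hypothetical. It does no pivoting, so it assumes systems that are safe to factor without it (e.g. diagonally dominant ones); a solver for inverse iteration, where T - lambda*I is nearly singular and indefinite, would likely need pivoting.

```python
import jax
import jax.numpy as jnp
from jax import lax


def thomas_solve(dl, d, du, b):
    """Solve T x = b for a single tridiagonal T (no pivoting).

    dl: sub-diagonal (n-1,), d: diagonal (n,), du: super-diagonal (n-1,),
    b: right-hand side (n,).
    """
    # Pad so row k sees (a_k, d_k, c_k, b_k), with a_0 = 0 and c_{n-1} = 0.
    a = jnp.concatenate([jnp.zeros(1, d.dtype), dl])
    c = jnp.concatenate([du, jnp.zeros(1, d.dtype)])

    def forward(carry, row):
        # Forward elimination: modified super-diagonal c' and right-hand side y'.
        c_prev, y_prev = carry
        a_k, d_k, c_k, b_k = row
        denom = d_k - a_k * c_prev
        c_new = c_k / denom
        y_new = (b_k - a_k * y_prev) / denom
        return (c_new, y_new), (c_new, y_new)

    init = (jnp.zeros((), d.dtype), jnp.zeros((), d.dtype))
    _, (cp, yp) = lax.scan(forward, init, (a, d, c, b))

    def backward(x_next, row):
        # Back substitution: x_k = y'_k - c'_k * x_{k+1}.
        cp_k, yp_k = row
        x_k = yp_k - cp_k * x_next
        return x_k, x_k

    _, x = lax.scan(backward, jnp.zeros((), d.dtype), (cp, yp), reverse=True)
    return x


# Batched version: map over a leading batch axis of every argument.
batched_thomas_solve = jax.vmap(thomas_solve)
```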
That sounds tricky. That's a shame; it makes it difficult to implement any of the popular GP inference engines that use the eigendecomposition of tridiagonal matrices as a way to compute log-determinants cheaply.
a way to compute log-determinants cheaply.
Perhaps I'm misunderstanding, but isn't log det(M) = sum_i log(lambda_i), where lambda_i are the eigenvalues (e.g. from eigh)? So eigenvalues are sufficient?
eigh is still expensive, no? It doesn't exploit the tridiagonal structure for cheaper compute.
Sorry, I meant eigh_tridiagonal(d, e, eigvals_only=True) instead of eigh; that is, compute only the eigenvalues, exploiting the tridiagonal structure. For a symmetric positive-definite tridiagonal matrix M with diagonal d and off-diagonal e, you can compute the log-determinant as
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal
def log_det_trid(d, e):  # d: diagonal, e: off-diagonal of a symmetric positive-definite tridiagonal M
    return jnp.sum(jnp.log(eigh_tridiagonal(d, e, eigvals_only=True)))
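This is true because log det(M) = log prod_i lambda_i = sum_i log lambda_i, where lambda_i are the eigenvalues of M returned by eigh_tridiagonal(..., eigvals_only=True). Summing log lambda_i elementwise requires every eigenvalue to be positive, i.e. M positive definite; otherwise jnp.log returns NaN (or -inf for a zero eigenvalue).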
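A quick numerical check of this identity, on an illustrative diagonally dominant (hence positive-definite) tridiagonal matrix, comparing against a dense jnp.linalg.slogdet reference:

```python
import jax.numpy as jnp
from jax.scipy.linalg import eigh_tridiagonal

# Illustrative symmetric positive-definite tridiagonal matrix.
d = jnp.array([4.0, 5.0, 6.0, 7.0])   # diagonal
e = jnp.array([1.0, -0.5, 2.0])       # off-diagonal
M = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)

# log det(M) = sum_i log(lambda_i), using only the eigenvalues.
eigvals = eigh_tridiagonal(d, e, eigvals_only=True)
log_det_from_eigvals = jnp.sum(jnp.log(eigvals))

# Dense O(n^3) reference for comparison.
sign, log_det_dense = jnp.linalg.slogdet(M)
```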
Even jax.grad of log(det(M)) doesn't require eigenvectors, only the inverse, since d log|det(M)| / dM = (M^{-1})^T (see section 2.1.4 or eq. 57 in The Matrix Cookbook).
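A small check of that gradient identity on a dense symmetric tridiagonal matrix with illustrative values (tridiag and log_abs_det are hypothetical helper names). The (d, e) gradients follow from d log det(M) = tr(M^{-1} dM): a change in d_k touches only entry (k, k), while a change in e_k touches both (k, k+1) and (k+1, k), hence the factor 2.

```python
import jax
import jax.numpy as jnp


def tridiag(d, e):
    # Assemble the symmetric tridiagonal matrix with diagonal d, off-diagonal e.
    return jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1)


def log_abs_det(M):
    # slogdet returns (sign, log|det M|); keep only the log-magnitude.
    return jnp.linalg.slogdet(M)[1]


d = jnp.array([4.0, 5.0, 6.0])
e = jnp.array([1.0, 0.5])
M = tridiag(d, e)
M_inv = jnp.linalg.inv(M)

# Matrix Cookbook eq. 57: d log|det M| / dM = (M^{-1})^T.
grad_M = jax.grad(log_abs_det)(M)

# Through the (d, e) parametrization: diag(M^{-1}) and 2 * superdiag(M^{-1}).
grad_d, grad_e = jax.grad(lambda d, e: log_abs_det(tridiag(d, e)),
                          argnums=(0, 1))(d, e)
```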
I need this for finding the eigenvectors of the Hessian after tridiagonalizing it with Lanczos iteration. Right now the function looks like:
It is also not mentioned in the documentation that this is not implemented; it just raises a NotImplementedError when you try to use it. If I shell out to SciPy on CPU, I then can't JIT the function. Are there any plans to implement calculation of eigenvectors for tridiagonal matrices?
Thank you, Katherine Crowson
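For the Lanczos use case described in the issue, one possible stop-gap while eigh_tridiagonal lacks eigenvectors: the Lanczos matrix T is only k x k, so it can be assembled densely and passed to jnp.linalg.eigh, and the resulting Ritz vectors mapped back through the Lanczos basis. This is a sketch, not the code from the issue; lanczos, ritz_pairs, and hvp are hypothetical names, it uses full reorthogonalization and a Python loop (unrolled under jit), and it assumes no breakdown (every beta nonzero).

```python
import jax
import jax.numpy as jnp


def lanczos(matvec, v0, k):
    """k Lanczos steps with full reorthogonalization.

    Returns the diagonal (alphas) and off-diagonal (betas) of the k x k
    tridiagonal matrix T, and the Lanczos vectors as the rows of Q (k x n).
    """
    Q = jnp.zeros((k, v0.shape[0]), v0.dtype)
    alphas = jnp.zeros(k, v0.dtype)
    betas = jnp.zeros(k - 1, v0.dtype)
    q = v0 / jnp.linalg.norm(v0)
    q_prev = jnp.zeros_like(q)
    beta_prev = 0.0
    for i in range(k):  # Python loop: unrolled under jit, so keep k modest
        Q = Q.at[i].set(q)
        w = matvec(q) - beta_prev * q_prev
        alpha = jnp.dot(w, q)
        w = w - alpha * q
        w = w - Q.T @ (Q @ w)  # reorthogonalize against all stored vectors
        alphas = alphas.at[i].set(alpha)
        if i < k - 1:
            beta = jnp.linalg.norm(w)  # assumes no breakdown (beta > 0)
            betas = betas.at[i].set(beta)
            q_prev, q, beta_prev = q, w / beta, beta
    return alphas, betas, Q


def ritz_pairs(alphas, betas, Q):
    """Eigenpairs of T via dense eigh, lifted back to the original space."""
    T = jnp.diag(alphas) + jnp.diag(betas, 1) + jnp.diag(betas, -1)
    theta, S = jnp.linalg.eigh(T)  # T is only k x k, so dense eigh is cheap
    return theta, Q.T @ S          # columns are Ritz vectors


def hvp(f, x, v):
    # Hessian-vector product via forward-over-reverse differentiation.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]
```

With k much smaller than the dimension, the Ritz pairs only approximate the extreme eigenpairs of the Hessian.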