google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support autodiff of Eigendecomposition with repeated eigenvalues #669

Open sethaxen opened 5 years ago

sethaxen commented 5 years ago

On v0.1.25 on macOS, I get the error below when computing gradients of the following jit-compiled function.

import numpy as onp
import jax.numpy as np
from jax import grad, jit

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(val))

grad_test = jit(grad(test))
grad_test_jc = jit(grad(jit(test)))

x = onp.eye(3, dtype=onp.double)
xc = onp.eye(3, dtype=onp.complex)

print(test(x))
print(grad_test(x))
print(grad_test_jc(x))
print(grad_test(xc))
3.0
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j]]

So far so good. But computing the gradient of the jit-compiled function with complex inputs fails with an error:

print(grad_test_jc(xc))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-10b24cdf8a93> in <module>
     19 
     20 
---> 21 print(grad_test_jc(xc))

/usr/local/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    105     _check_args(jaxtupletree_args)
    106     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
--> 107     jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
    108     return build_tree(out_tree(), jaxtupletree_out)
    109 

/usr/local/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **kwargs)
    543   if top_trace is None:
    544     with new_sublevel():
--> 545       ans = primitive.impl(f, *args, **kwargs)
    546   else:
    547     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args)
    452   fun, out_tree = flatten_fun(fun, in_trees)
    453 
--> 454   compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
    455   try:
    456     flat_ans = compiled_fun(*flat_args)

/usr/local/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, *abstract_args)
    473     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
    474     assert not env  # no subtraces here (though cond might eventually need them)
--> 475     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
    476     del master, consts, jaxpr, env
    477   handle_result = result_handler(result_shape)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in compile_jaxpr(jaxpr, const_vals, *abstract_args)
    135 def compile_jaxpr(jaxpr, const_vals, *abstract_args):
    136   arg_shapes = list(map(xla_shape, abstract_args))
--> 137   built_c = jaxpr_computation(jaxpr, const_vals, (), *arg_shapes)
    138   result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
    139   return built_c.Compile(arg_shapes, xb.get_compile_options(),

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
    173             map(c.GetShape, map(read, const_bindings + freevar_bindings)),
    174             *in_shapes)
--> 175         for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
    176     subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
    177                for subc, (_, const_bindings, freevar_bindings)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
    173             map(c.GetShape, map(read, const_bindings + freevar_bindings)),
    174             *in_shapes)
--> 175         for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
    176     subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
    177                for subc, (_, const_bindings, freevar_bindings)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
    167   for eqn in jaxpr.eqns:
    168     in_nodes = map(read, eqn.invars)
--> 169     in_shapes = map(c.GetShape, in_nodes)
    170     subcs = [
    171         jaxpr_computation(

/usr/local/lib/python3.7/site-packages/jax/util.py in safe_map(f, *args)
     41   for arg in args[1:]:
     42     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 43   return list(map(f, *args))
     44 
     45 

/usr/local/lib/python3.7/site-packages/jaxlib/xla_client.py in GetShape(self, operand)
    876 
    877   def GetShape(self, operand):
--> 878     return _wrap_shape(self._builder.GetShape(operand))
    879 
    880   def SetOpMetadata(self, op_metadata):

RuntimeError: Invalid argument: Binary op add with different element types: c64[3,3] and f32[1,3].

JAX built from source produces the same error.

sethaxen commented 5 years ago

Might not be related, but even without jit compilation and complex inputs, gradient computation of a function of eigenvectors fails.

import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

x = onp.eye(3, dtype=onp.double)

print(test(x))
print(grad_test(x))
3.0
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
hawkinsp commented 5 years ago

PR #670 fixes the first bug; we were incorrectly declaring that the eigenvalues of a complex matrix were complex, leading to a type error when under a jit scope.

Not sure about the second bug yet.

hawkinsp commented 5 years ago

I'm wondering if the second case is happening because one of the assumptions of the JVP rule for eigh is that the eigenvalues are distinct. In this case, all the eigenvalues are 1.
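
A quick numerical illustration of that hypothesis (my own, plain NumPy, not from the JVP rule's source): the eigenvector JVP involves factors 1/(λ_j - λ_i) for i ≠ j, which blow up when eigenvalues coincide and turn into nan once multiplied by zero couplings downstream.

import numpy as onp

lam = onp.linalg.eigvalsh(onp.eye(3))      # all eigenvalues equal 1
diff = lam[None, :] - lam[:, None]         # λ_j - λ_i
off_diag = ~onp.eye(3, dtype=bool)
with onp.errstate(divide="ignore"):
    print(1.0 / diff[off_diag])            # all inf; 0 * inf downstream gives nan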

This is at the limits of my linear algebra knowledge, @mattjj do you have insights here?

hawkinsp commented 5 years ago

See the discussion on the implementation here: https://github.com/google/jax/blob/master/jax/lax_linalg.py#L155

mattjj commented 5 years ago

Wow, that's a nice comment! Thanks @levskaya.

I don't have any insights. That comment taught me things. I think the case of repeated eigenvalues might come down to "contributions welcome".

sethaxen commented 5 years ago

Thanks @hawkinsp for the quick response and fix! It looks like repeated eigenvalues are probably the cause of the second bug:

import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

onp.random.seed(42)
x = onp.diag(
    onp.ones(3, dtype=onp.double) +
    onp.random.normal(0, 1e-6, size=3)
)

x2 = onp.diag(
    onp.ones(3, dtype=onp.double) +
    onp.random.normal(0, 1e-8, size=3)
)

print(grad_test(x))
print(grad_test(x2))
[[-5.9604641e-08  0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -5.9604641e-08  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  1.1920929e-07]]
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
mattjj commented 5 years ago

@sdaxen do you need autodiff of eigendecomposition with repeated eigenvalues? If not, we should probably close this issue until someone actually asks for it. (That way we can keep all the "enhancement" issues tracking things that users have specifically asked for.)

sethaxen commented 5 years ago

@mattjj no, it's not a priority. I do need repeated eigenvalues, but I'm only test driving jax for the moment while doing my main work with a different system. Feel free to close.

mattjj commented 5 years ago

Thanks for the info!

We're very interested to hear about JAX's shortcomings so that we can work to fix them, so if there's something about JAX that makes it unsuitable for your work, please let us know!

@hawkinsp is it a better policy to close this issue, or leave it open and just keep in mind that we don't have any users specifically asking for it yet?

hawkinsp commented 5 years ago

I think we should leave these kinds of issues open. It makes them more easily searchable should someone else have the same problem; I'd rather have one issue than two.

mattjj commented 5 years ago

Cool, makes sense to me! At least we clarified how to prioritize it, then.

hawkinsp commented 5 years ago

I also observe that TF has the same limitation: https://github.com/tensorflow/tensorflow/blob/f33aa592f92e233aeb00198d0caab80eaa89afe9/tensorflow/python/ops/linalg_grad.py#L314

MiroFurtado commented 3 years ago

Forgive my general lack of knowledge, I'm just beginning to look into jax/remind myself of some linear algebra. But on this issue:

import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

x = onp.eye(3, dtype=onp.double)

print(test(x))
print(grad_test(x))

Is the derivative here even well defined? It seems to me like in the degenerate case, there is no unique eigenvector so the np.real(np.sum(vec)) could be a whole range of possible values depending on your choice of basis, no? This is interesting stuff, I'd be curious how you go about learning more.

levskaya commented 3 years ago

As you note, the general problem is quite tricky. To use physics parlance, there are two cases: you can have a matrix with degenerate eigenvalues where the perturbation (gradient direction) "breaks the symmetry" and causes the degenerate eigenvalues to split, and then you have the case where the perturbation preserves the degeneracy... which generally makes talking about eigenvector derivatives very tricky / ill-defined with simple approaches. Especially if you're dealing with the general complex case, where the eigenvector phase has additional freedom.

There are a few papers that seem to offer general algorithmic approaches, but they're complicated enough that no one has sat down to try to implement them to see how they'd work:

shoyer commented 3 years ago

The hard case is differentiating eigenvectors in the presence of degeneracies. Eigenvalue derivatives are still fine, either way.

I believe my pull request #1665 actually has a working JVP (forward mode) gradient implementation for eigh with degeneracies, but it can't be transposed, which means it doesn't work for backward mode differentiation.

In general, I don't think it's possible to define backward-mode gradients of eigenvectors for arbitrary functions of degenerate eigenvectors -- the gradients simply don't always exist. I'll see if I can work up a good counter-example.

From a practical perspective, it seems like the better idea is to differentiate a higher-level function like a power series, which always has well-defined derivatives. Typically a program that uses eigenvectors corresponding to degenerate eigenvalues is ultimately using the eigenvectors to calculate something like this anyway, because otherwise its output would depend on arbitrary choices made by the linear algebra library.

EDIT: to clarify, by "power series" I really mean "matrix valued function" here
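
As a concrete illustration of this point (my example, not from the thread): a function built from jax.scipy.linalg.expm is a well-defined matrix-valued function, and since JAX's expm is implemented via Padé approximation rather than eigh, I believe its gradient stays finite even for inputs with repeated eigenvalues.

import jax.numpy as jnp
from jax import grad
from jax.scipy.linalg import expm

def f(x):
    # sum of the entries of the matrix exponential: independent of any
    # eigenbasis choice, so its derivative is well defined
    return jnp.sum(expm(x))

print(grad(f)(jnp.eye(3)))  # finite, despite all eigenvalues of eye(3) being equal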

MiroFurtado commented 3 years ago

@shoyer Just so I'm clear, you're saying that if we have a matrix A = P D P^-1, typically the reason you would want to do the eigendecomposition is so that you can evaluate a function f(A) by doing f(A) = P f(D) P^-1, which is independent of the arbitrary choice of eigenvectors (and other uses would be out of scope)? If so, I'm still confused as to how the np.sum operation at issue here would be defined.

Or are you talking about an entirely different application of power-series that I'm unfamiliar with.

shoyer commented 3 years ago

@shoyer Just so I'm clear, you're saying that if we have a matrix A = P D P^-1, typically the reason you would want to do the eigendecomposition is so that you can evaluate a function f(A) by doing f(A) = P f(D) P^-1, which is independent of the arbitrary choice of eigenvectors (and other uses would be out of scope)?

That's exactly right. I hypothesize that every real-world use case for calculating eigenvectors is using them in order to evaluate a matrix-valued function of some form.

If so, I'm still confused as to how the np.sum operation that is at issue here would be defined?

The example in the first post was with eigenvalue derivatives. As noted in https://github.com/google/jax/issues/669#issuecomment-489328464, it's been fixed.

shoyer commented 3 years ago

It's probably worth noting that the example failure case for eigenvector derivatives from https://github.com/google/jax/issues/669#issuecomment-489303348 is not a well-defined matrix-valued function:

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

E.g., suppose x = np.eye(2). Then normalized eigenvectors vec could be either [[1, 0], [0, 1]] or [[1/sqrt(2), 1/sqrt(2)], [1/sqrt(2), -1/sqrt(2)]], so test(x) could be either 2 or 2/sqrt(2).
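
For a quick numeric check of those two bases (my own snippet, plain NumPy):

import numpy as onp

v1 = onp.eye(2)                                          # one valid orthonormal eigenbasis of eye(2)
v2 = onp.array([[1.0, 1.0], [1.0, -1.0]]) / onp.sqrt(2)  # another valid orthonormal eigenbasis
print(v1.sum(), v2.sum())                                # 2.0 vs ~1.414: np.sum(vec) depends on the basis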

sethaxen commented 3 years ago

FWIW, I originally encountered this while playing around with the matrix exponential of hermitian matrices, which is a power series function. For such functions that internally use the eigendecomposition, we can nevertheless write forward- and reverse-mode rules that almost completely account for the degeneracy of the eigendecomposition.

shoyer commented 3 years ago

I recently stumbled across this paper, which seems to provide exactly the algorithm we need here:

It looks like we could use the forward derivative for a JVP rule in JAX, which would suffice for auto-diff as long as we know how to implement and transpose a Sylvester solve (i.e., scipy.linalg.solve_sylvester):

(screenshot: the forward-derivative equation from the paper)

sethaxen commented 3 years ago

It looks like we could use the forward derivative for a JVP rule in JAX, which would suffice for auto-diff as long as we know how to implement and transpose a Sylvester solve (i.e., scipy.linalg.solve_sylvester)

Shouldn't be too hard. The JVP for a Sylvester solve is just another Sylvester solve. If X = sylvester_solve(A, B, Q), then Ẋ = sylvester_solve(A, B, Q̇ - Ȧ@X - X@Ḃ) (the solve is linear wrt its 3rd argument, so this is fine).
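
For concreteness, a minimal sketch of that rule (my own, assuming a SciPy-backed solve wrapped with jax.pure_callback; this is not the prototype referenced later in the thread):

import jax
from scipy.linalg import solve_sylvester

@jax.custom_jvp
def sylvester_solve(a, b, q):
    # Solve A X + X B = Q on the host via SciPy.
    out_shape = jax.ShapeDtypeStruct(q.shape, q.dtype)
    return jax.pure_callback(
        lambda a_, b_, q_: solve_sylvester(a_, b_, q_).astype(q_.dtype),
        out_shape, a, b, q)

@sylvester_solve.defjvp
def sylvester_solve_jvp(primals, tangents):
    a, b, q = primals
    da, db, dq = tangents
    x = sylvester_solve(a, b, q)
    # Differentiating A X + X B = Q gives A Ẋ + Ẋ B = Q̇ - Ȧ X - X Ḃ,
    # i.e. another Sylvester solve with the same (A, B).
    dx = sylvester_solve(a, b, dq - da @ x - x @ db)
    return x, dx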

But Sylvester solves usually compute a Schur decomposition of A and B and then use a triangular Sylvester solve like LAPACK's trsyl, so this JVP would compute the Schur decompositions twice. It's more efficient to reuse the Schur decomposition, but I imagine in JAX that could pose problems with higher order AD? Or does JAX have a traceable Schur decomposition?

sethaxen commented 3 years ago

I spent some time going through the paper today, and as far as I can tell, this approach only handles exactly degenerate matrices, not almost-degenerate matrices, and in the case of exactly degenerate matrices, a Sylvester solver is not actually necessary and does not help things. Using their notation, for a standard eigenvalue problem M=I and M'=0.

In the paper's notation, the usual JVP in Giles' paper and elsewhere is

F_ij = {
    0                if i = j
    inv(λ_j - λ_i)   otherwise
}
K' = X^T A' X
Λ' = I ∘ K'
X' = X (F ∘ K')

After simplification, their contribution wrt degeneracy amounts to a simple modification to the matrix F:

F_ij = {
    0                if i = j
    0                if λ_i = λ_j and K'_ij = 0
    inv(λ_j - λ_i)   otherwise
}
When `M=I` and `M'=0`, their JVP simplifies to

K' = X^T A' X                       # intermediate
Λ' = I ∘ K'
A Y' - Y' Λ = X (K' - (I ∘ K'))     # solve for Y' using sylvester_solve
X' = X (D ∘ (X^T Y') - X^T Y')

But we can take this further by multiplying both sides of Sylvester's equation by `X^T`. Then we have

let W' = X^T Y' and Z' = K' - (I ∘ K')
Λ W' - W' Λ = Z'                    # Sylvester's equation, but now both matrices on the left-hand side are diagonal
λ_i W'_ij - W'_ij λ_j = Z'_ij       # the same equation in scalar form
W'_ij (λ_i - λ_j) = Z'_ij           # rearranged to a scalar Sylvester's equation

This Sylvester equation appears in the standard derivation, the same as in Giles and elsewhere. We can solve for `W'_ij = Z'_ij / (λ_i - λ_j)`, but this is what causes instability. Their approach does not handle this. It only handles the case where `λ_i - λ_j = 0`. In that case, for `Z'_ij ≠ 0` there are no solutions, and overflow is unavoidable. If `Z'_ij = 0`, then `W'_ij` can be any finite number, and their approach would choose `W'_ij = 0` (this is what the elementwise product with `D` does).

So as far as I can tell, this does not resolve the issue of degeneracy, especially since even if a matrix is constructed to have exactly equal eigenvalues, floating-point error means the computed eigenvalues will usually not be exactly equal. All it does is ensure that, for certain programs, when a matrix with exactly equal eigenvalues is factorized to produce exactly equal eigenvalues, and when the tangent has a certain structure, 0/0 does not happen. That is useful, but I think it will only occur in very extreme and rare cases.

One last point is that this modification applies to the standard eigendecomposition as well, not just for symmetric or real matrices.
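
To make the above concrete, here is a rough sketch (my own reading of the formulas above, exact-equality checks only, untested against the paper) of an eigh JVP using the modified F:

import jax.numpy as jnp

def eigh_jvp_modified(a, a_dot):
    # A = X Λ X^H, K' = X^H A' X
    lam, x = jnp.linalg.eigh(a)
    k_dot = x.T.conj() @ a_dot @ x
    diff = lam[None, :] - lam[:, None]              # entry (i, j) is λ_j - λ_i
    on_diag = jnp.eye(lam.shape[0], dtype=bool)
    # modified F: also zero when λ_i = λ_j and K'_ij = 0 (exact equality only)
    zero_out = on_diag | ((diff == 0) & (k_dot == 0))
    f = jnp.where(zero_out, 0.0, 1.0 / jnp.where(zero_out, 1.0, diff))
    lam_dot = jnp.real(jnp.diag(k_dot))             # Λ' = I ∘ K' (real for Hermitian A)
    x_dot = x @ (f * k_dot)                         # X' = X (F ∘ K')
    return (lam, x), (lam_dot, x_dot)

This mirrors the standard rule except for the extra zero_out condition; as noted above, that condition only fires for exactly equal floating-point eigenvalues, and near-degenerate pairs still produce huge entries in F.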

oxinabox commented 3 years ago

I recently stumbled across this paper, which seems to provide exactly the algorithm we need here:

I noticed the author of the paper, @mfkasim1 is on GitHub, and might be interested in this discussion.

mfkasim1 commented 3 years ago

Thanks for the tag. The same issue is also discussed in pytorch: https://github.com/pytorch/pytorch/issues/47599 with PR https://github.com/pytorch/pytorch/pull/50942. I also implemented the algorithm in my library, xitorch, starting from this line: https://github.com/xitorch/xitorch/blob/a8ce4bf234fe839682586b7dc56f869ee1dc51d3/xitorch/linalg/symeig.py#L324 or this line for a simple eigendecomposition: https://github.com/xitorch/xitorch/blob/a8ce4bf234fe839682586b7dc56f869ee1dc51d3/xitorch/_impls/linalg/symeig.py#L47

The algorithm in the paper basically only works if the loss function does not depend directly on the degenerate eigenvectors, but it can depend on the space spanned by the degenerate eigenvectors. Therefore, additional checks are necessary (eq. 2.8 for forward diff, 2.13-2.15 for backward diff). If the additional checks pass, then it also works for the near-degenerate case. However, I haven't thought about complex and non-symmetric matrices, so it might not work there.

I have tried the algorithm in my differentiable density functional theory (DFT) simulation (there are a lot of degenerate eigenvalues) and it works nicely (i.e. it passes pytorch's gradcheck and gradgradcheck).

shoyer commented 3 years ago

@mfkasim1 thanks for joining us!

I share @sethaxen's concern about almost degenerate eigenvalues. In such cases, the standard auto-diff rules for eigh can be numerically unstable, because it includes terms like 1/(λ_j - λ_i). Have you thought about this case, or perhaps it does not arise for your applications in density functional theory?

mfkasim1 commented 3 years ago

@shoyer The case in my DFT application is that it is supposed to have exactly the same eigenvalues theoretically, but numerically the retrieved eigenvalues are only close to each other (near-degenerate). In the near-degenerate case, the denominator (λ_j - λ_i) is close to 0, so unless you have a near-0 numerator (i.e. eqs 2.8 and 2.13-2.15), the numerical instability is unavoidable (just like 1/x). In the paper, I just consider the case where the numerator is supposed to be 0. If it's not 0 (due to numerical error), it is assumed to be 0. If you want to consider a case where the numerator is supposed to be a very small value but not 0, then this is not covered in the paper.

In my DFT case, this is sufficient, because the loss function does not depend directly on the degenerate eigenvectors (it depends on the space spanned by the eigenvectors), so the numerator (eq 2.13) is supposed to be 0.

shoyer commented 3 years ago

I think it is fine to restrict the autodiff rules for eigh to the case where the output is invariant to the choice of basis within the degenerate subspace. If this isn't true, then the calculation isn't a well-defined function (and would differ between LAPACK implementations).

platawiec commented 3 years ago

I came across this conversation and wanted to leave this note as a reference for near-degenerate eigenvectors, specifically the transformation of Eq. 10 into Eq. 11 to account for near-degeneracy: https://github.com/mitmath/18335/blob/master/notes/adjoint/eigenvalue-adjoint.pdf . Hopefully you find it useful, though some translation may be in order.

shoyer commented 3 years ago

(posting a comment from last week, that I thought I had already submitted!)

Shouldn't be too hard. The JVP for a Sylvester solve is just another Sylvester solve. If X = sylvester_solve(A, B, Q), then Ẋ = sylvester_solve(A, B, Q̇ - Ȧ@X - X@Ḃ) (the solve is linear wrt its 3rd argument, so this is fine).

Yes, indeed, this does look pretty straightforward!

(Note for Googlers: here's a chat thread that contains links to a prototype for a differentiable sylvester solve in JAX, written in terms of SciPy operations)

But Sylvester solves usually compute a Schur decomposition of A and B and then use a triangular Sylvester solve like LAPACK's trsyl, so this JVP would compute the Schur decompositions twice. It's more efficient to reuse the Schur decomposition, but I imagine in JAX that could pose problems with higher order AD? Or does JAX have a traceable Schur decomposition?

I think we can use the same trick we use for linear solves: compute the matrix factorization/decompositions first (without gradients), and then pass it into the auto-diff primitive: https://github.com/google/jax/blob/6e1cd395e890b122008af2384b33c6aaa5374c28/jax/_src/lax/linalg.py#L262-L269

Technically, the auto-diff primitive becomes "Sylvester solve from a Schur decomposition" rather than "Sylvester solve from scratch".
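
To illustrate that pattern, here is a sketch (mine, not the linked JAX internals) using lax.custom_linear_solve: the factorization is computed up front without gradients, and the differentiated primitive only reuses it.

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def psd_solve(a, b):
    # Factor once, outside the differentiated primitive; gradients w.r.t. `a`
    # flow through the matvec, while the solve just reuses the factors.
    factors = cho_factor(jax.lax.stop_gradient(a))
    def solve(matvec, rhs):
        return cho_solve(factors, rhs)
    return jax.lax.custom_linear_solve(lambda v: a @ v, b, solve, symmetric=True)

A Sylvester analogue would compute the Schur factors of A and B up front and make the primitive "triangular Sylvester solve given Schur factors", as described above.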

LionSR commented 2 years ago

(quoting @shoyer's comment above in full, from "Shouldn't be too hard..." through "...rather than 'Sylvester solve from scratch'.")

Hi! I was wondering if there are any plans to implement the Sylvester solver in JAX, perhaps based on the draft implementation you mentioned, or on jax.scipy.linalg.schur and jax.scipy.linalg.solve_triangular as in the Bartels-Stewart algorithm? I would be interested in an implementation. Any tips or status update would be appreciated!

HHalva commented 1 year ago

Any progress on this? Encountering problems with almost degenerate eigenvalues...

yonghakim commented 1 year ago

Hi, there is a regularization technique to circumvent this degeneracy; you can find the references in the docstring. I'm not an expert at math, but this seems to work so far.

code from my repo (https://github.com/kc-ml2/meent/blob/main/meent/on_jax/emsolver/primitives.py)

import jax
import jax.numpy as jnp

from functools import partial

def conj(arr):
    return arr.real + arr.imag * -1j
    # return arr.conj()

@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3))
def eig(x, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'):

    _eig = jax.jit(jnp.linalg.eig, device=jax.devices('cpu')[0])

    eigenvalues_shape = jax.ShapeDtypeStruct(x.shape[:-1], type_complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(x.shape, type_complex)

    result_shape_dtype = (eigenvalues_shape, eigenvectors_shape)
    if device == 'cpu':
        res = _eig(x)
    else:
        res = jax.pure_callback(_eig, result_shape_dtype, x)

    return res

def eig_fwd(x, type_complex, perturbation, device):
    # forward pass: compute the decomposition once, save (outputs, input) as residuals
    res = eig(x, type_complex, perturbation)
    return res, (res, x)

def eig_bwd(type_complex, perturbation, device, res, g):
    """
    Gradient of a general square (complex valued) matrix
    Eq 2~5 in https://www.nature.com/articles/s42005-021-00568-6
    Eq 4.77 in https://arxiv.org/pdf/1701.00392.pdf
    Eq. 30~32 in https://www.sciencedirect.com/science/article/abs/pii/S0010465522002715
    https://github.com/kch3782/torcwa
    https://github.com/weiliangjinca/grcwa
    https://github.com/pytorch/pytorch/issues/41857
    https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation
    https://discuss.pytorch.org/t/autograd-on-complex-numbers/144687/3
    """

    (eig_val, eig_vector), x = res
    grad_eigval, grad_eigvec = g

    grad_eigval = jnp.diag(grad_eigval)
    W_H = eig_vector.T.conj()

    # Pairwise eigenvalue differences; dividing by (|Fij|**2 + perturbation)
    # regularizes the reciprocal so it stays finite when eigenvalues (nearly) coincide.
    Fij = eig_val.reshape((1, -1)) - eig_val.reshape((-1, 1))
    Fij = Fij / (jnp.abs(Fij) ** 2 + perturbation)
    Fij = Fij.at[jnp.diag_indices_from(Fij)].set(0)

    # diag_indices = jnp.arange(len(eig_val))
    # Eij = eig_val.reshape((1, -1)) - eig_val.reshape((-1, 1))
    # Eij = Eij.at[diag_indices, diag_indices].set(1)
    # Fij = 1 / Eij
    # Fij = Fij.at[diag_indices, diag_indices].set(0)

    grad = jnp.linalg.inv(W_H) @ (grad_eigval.conj() + Fij * (W_H @ grad_eigvec.conj())) @ W_H
    grad = grad.conj()
    if not jnp.iscomplexobj(x):
        grad = grad.real

    return grad,

eig.defvjp(eig_fwd, eig_bwd)
miaoL1 commented 7 months ago

(quoting @yonghakim's comment above in full, including the regularized eig/custom_vjp code from the meent repo)

This code works fine! BTW, I'm wondering whether there is a "jvp" version, since I'm working on a framework that only supports forward-mode autodiff.
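
A forward-mode counterpart might look something like the following sketch (mine, not from the meent repo), reusing the same regularized reciprocal of the eigenvalue gaps in a custom_jvp; note that jnp.linalg.eig itself only runs on CPU:

import jax
import jax.numpy as jnp

PERTURBATION = 1e-10  # same role as `perturbation` in the custom_vjp above

@jax.custom_jvp
def eig_reg(x):
    return jnp.linalg.eig(x)

@eig_reg.defjvp
def eig_reg_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    eig_val, eig_vec = eig_reg(x)
    k_dot = jnp.linalg.inv(eig_vec) @ x_dot @ eig_vec      # analogue of V^-1 A' V
    diff = eig_val[None, :] - eig_val[:, None]             # λ_j - λ_i
    f = diff.conj() / (jnp.abs(diff) ** 2 + PERTURBATION)  # regularized 1/(λ_j - λ_i)
    f = f.at[jnp.diag_indices_from(f)].set(0)
    val_dot = jnp.diag(k_dot)                              # eigenvalue tangents
    vec_dot = eig_vec @ (f * k_dot)                        # eigenvector tangents
    return (eig_val, eig_vec), (val_dot, vec_dot)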

EricaCMitchell commented 5 months ago

I am also running into issues with needing degenerate eigenvalues for the computation of molecular energies using wavefunction methods. Degenerate eigenvalues are quite common in quantum chemistry, as evidenced by Kasim's work on differentiable density functional theory and by the symmetry present in many molecules.

Is the fact that backward-mode AD can't handle degenerate eigenvalues a hindrance to implementing the forward-mode JVP?