jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

autodiff support for jax.numpy.linalg.eig #2748

Open froystig opened 4 years ago

froystig commented 4 years ago

Note that eigh is already taken care of.

shoyer commented 4 years ago

I'm pretty sure we could achieve this with a minor variation of our existing JVP rule for eigh, replacing U.T.conj() -> inv(U) (of course it should really use an LU solve rather than computing the inverse directly).
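For reference, here is a short derivation of the first-order perturbation formulas behind this suggestion. It assumes A is diagonalizable with distinct eigenvalues. V is the matrix of right eigenvectors and Λ the diagonal matrix of eigenvalues; C (with dV = VC) is notation introduced here.

```latex
\begin{aligned}
&A V = V \Lambda
  \;\Rightarrow\; dA\,V + A\,dV = dV\,\Lambda + V\,d\Lambda \\
&dV = V C
  \;\Rightarrow\; V^{-1} dA\,V + \Lambda C - C \Lambda = d\Lambda \\
&d\lambda_i = \big(V^{-1} dA\,V\big)_{ii}, \qquad
  C_{ij} = \frac{\big(V^{-1} dA\,V\big)_{ij}}{\lambda_j - \lambda_i} \quad (i \neq j)
\end{aligned}
```

These equations leave the diagonal entries C_ii undetermined. They encode the freedom to rescale each eigenvector, so a normalization (gauge) convention has to fix them. For Hermitian A with unitary V, V^{-1} = V^H, which recovers the form of the eigh rule.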

ianwilliamson commented 4 years ago

Just wanted to throw in a +1 for wanting this to be implemented.

j-towns commented 4 years ago

@shoyer do you have a reference for it? I've just been working through the math by hand and it seems what you said is correct, except that you have to do a slightly awkward correction to ensure that dU.T @ U has ones down the diagonal (which I think is required - this comes from the constraint that the eigenvectors are normalized). Anyway I think I will draft an implementation today.

Edit: It's implemented in Autograd https://github.com/HIPS/autograd/blob/master/autograd/numpy/linalg.py#L152-L173, with a reference to https://arxiv.org/pdf/1701.00392.pdf, eq 4.77.

Edit 2: The JVP equations in that paper are 4.60 and 4.63, but I think 4.63 (the JVP for the eigenvectors) is wrong. The statement above 4.63 ("...can not influence the amplitude of the eigenvectors...") is correct, but I don't think they translated that constraint correctly into math. I've tried implementing both their version and my own. Neither is working yet, so I'm not 100% sure whether I'm right about this.

j-towns commented 4 years ago

Also @shoyer, how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of a circular dependency.

For now I'll use inv as you suggested above.

shoyer commented 4 years ago

Section 3.1 from this reference in a comment under eigh_jvp_rule (in lax_linalg.py) works through the general case of how to calculate eigenvector derivative: https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

shoyer commented 4 years ago

Also @shoyer, how should I do a solve (inv(a) @ b for square matrices a and b)? I think I can't use jax.numpy.linalg.solve from jax.lax_linalg because of a circular dependency.

My suggestion would be to use a local import, in the JVP rule, e.g.,

def eig_jvp_rule(...):
  from jax.numpy.linalg import solve
  ...

You could also try refactoring, but this is the usual hack for circular dependency challenges.
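Here is a minimal sketch of that local-import pattern, applied to the eigenvalue-only case. The names eigvals_sketch and eigvals_sketch_jvp are hypothetical. The rule implements the eigenvalue JVP dλ_i = (V⁻¹ dA V)_ii, which the thread identifies as eq. 4.60 of the paper linked above. It uses a linear solve rather than an explicit inverse. This standalone sketch has no circular dependency, so the import inside the rule only illustrates where it would go.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def eigvals_sketch(a):
    """Eigenvalues of a square, possibly non-Hermitian matrix (hypothetical helper)."""
    return jnp.linalg.eig(a)[0]


@eigvals_sketch.defjvp
def eigvals_sketch_jvp(primals, tangents):
    # Local import inside the JVP rule, as suggested above.
    from jax.numpy.linalg import solve

    (a,), (da,) = primals, tangents
    w, v = jnp.linalg.eig(a)
    # d(lambda_i) = (V^{-1} dA V)_ii, computed with solve instead of inv(V).
    vinv_da_v = solve(v, da.astype(v.dtype)) @ v
    return w, jnp.diagonal(vinv_da_v)


a = jnp.array([[2.0, 1.0], [0.5, 3.0]])
da = jnp.array([[0.3, -0.2], [0.1, 0.4]])
w, dw = jax.jvp(eigvals_sketch, (a,), (da,))
# The eigenvalues sum to trace(A), so their tangents should sum to trace(dA).
print(dw.sum(), jnp.trace(da))
```

The checks avoid depending on the order in which eig returns eigenvalues. The tangent of the sum must equal trace(dA), and the tangent of the product must equal the derivative of det(A).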

hawkinsp commented 4 years ago

I'd be tempted to at least try the refactoring of moving the guts of solve into lax_linalg.py.

j-towns commented 4 years ago

Cool, I've moved it in the draft PR; it wasn't too bad to do. I'm still getting incorrect values for the eig derivatives, though 😔.

j-towns commented 4 years ago

The JVP seems to be correct now that I've relaxed the test tolerance slightly, but the VJPs are way out, I'm not sure why that is yet.

j-towns commented 4 years ago

I notice that testEighGrad is currently skipped because 'Test fails with numeric errors'. I wonder if the problems I'm seeing are related, since the eig JVP is mostly copied from the eigh JVP.

j-towns commented 4 years ago

OK I think I know why what I have is incorrect. Eigenvectors are only unique up to (complex) scalar multiple. The eigenvectors returned by numpy.linalg.eig are normalized so that they have length 1 (I already knew this), and also so that their largest component is real (see http://www.netlib.org/lapack/lapack-3.1.1/html/dgeev.f.html). That constraint I was not previously aware of and I think it might take some work to correct the derivations + implementation that I have.

j-towns commented 4 years ago

A similar issue might also explain why the derivative tests for eigh are failing - the eigenvectors are normalized so they have length 1, but are still only unique up to multiplication by a complex scalar whose absolute value is 1 (i.e. there is one degree of freedom per eigenvector). It's not clear from the low level eigh docs (http://www.netlib.org/lapack/lapack-3.1.1/html/zheevd.f.html) how this non-uniqueness is addressed.

Edit: just running np.linalg.eigh on a couple of inputs it looks like the eigenvectors are normalized so that the first component of each is real. It seems a bit strange that eigh uses a different convention to eig, and this means that you'll get np.linalg.eigh(x) != np.linalg.eig(x) for complex, hermitian x. The eigh convention should be easier to differentiate, and maybe we should change our eig_p primitive to match the eigh convention, so that lax_linalg.eig(x) == lax_linalg.eigh(x) for all hermitian x.

shoyer commented 4 years ago

We could certainly pick a new convention for normalizing vectors from eig if that makes it easier to differentiate. The downside is that this would probably require a bit more computation. If it's only O(n^2) time, I would say definitely go for it. It is more questionable if we need dense matrix/matrix multiplication, which is O(n^3). In the latter case we might add an optional argument for making eig differentiable.

For what it's worth, I have a feeling that the right answer for how to differentiate eig/eigh in most cases is: don't, precisely because eigendecomposition is often not a well-defined function. The right function to differentiate is something downstream of the eigendecomposition, where the outputs of the numerical method become a well-defined function, e.g., the result of a matrix power series calculation. If we can characterize the full set of such "well-defined functions of eigendecompositions", then perhaps those are the right primitives for which to define autodiff rules.

j-towns commented 4 years ago

Yeah I agree. It would be very weird and likely a bug if a user implemented a function that depended on the length of an eigenvector, since the normalization is essentially an implementation detail. Catering for these design decisions with correct derivatives is also really awkward, so maybe we should indeed look for another level at which to provide derivatives.

j-towns commented 4 years ago

@ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

j-towns commented 4 years ago

Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

shoyer commented 4 years ago

Also I think it is reasonable to support eigh of a real symmetric matrix, where there is a pretty obvious and straightforward unique value and derivative.

Even eigh is only uniquely defined if the eigenvalues are unique. If degeneracies are valid (common in many applications) it isn't a well defined function.

momchilmm commented 4 years ago

We had a related discussion when I was fixing eigh in autograd: https://github.com/HIPS/autograd/pull/527

Essentially, the vjp there works for objective functions that do not depend on the arbitrary phase of the eigenvectors (the "gauge"), and the tests are written for such functions. This is because in a general solver this phase is just arbitrary, so even finite-difference derivatives won't work, i.e. eig(X) and eig(X + dX) can spit out eigenvectors with an arbitrary phase difference. It sounds like in jax you are actually setting the gauge (largest element to be real). You could try to make the vjp account for that and match the finite-difference derivative under that gauge, but I think you can't really expect the user to know that you're doing that. So if I'm a user and I have a function that depends on the phase of an eigenvector, the correct approach is to manually set the gauge to whatever I want it to be, in a way tracked by jax. In other words: you can first get the vjp to work for gauge-independent functions, and then add the normalization on top of that.
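Here is a minimal sketch of fixing the gauge yourself with ordinary jnp operations, so that JAX traces the normalization. The helper name is hypothetical. The convention is an assumption for illustration: unit norm with a real, non-negative first component, similar to what was observed for eigh above. The rule assumes every column's first entry is nonzero.

```python
import jax.numpy as jnp


def fix_gauge_first_component_real(v):
    """Rescale each column of v to unit norm with a real, non-negative first entry.

    Hypothetical helper. Assumes v[0, j] != 0 for every column j.
    """
    # Remove the arbitrary magnitude of each column.
    v = v / jnp.linalg.norm(v, axis=0, keepdims=True)
    first = v[0, :]
    # Multiplying by conj(first) / |first| rotates the first entry's phase to 0.
    phase = jnp.conj(first) / jnp.abs(first)
    return v * phase[jnp.newaxis, :]


v = jnp.array([[1.0 + 1.0j, 0.5j], [2.0 - 1.0j, 1.0 + 0.0j]], dtype=jnp.complex64)
# Any nonzero complex rescaling of the columns gives the same gauge-fixed result.
scale = jnp.array([3.0 - 2.0j, -0.5 + 0.1j], dtype=jnp.complex64)
g1 = fix_gauge_first_component_real(v)
g2 = fix_gauge_first_component_real(v * scale[jnp.newaxis, :])
print(g1)
```

The result no longer depends on the phase or scale the solver happened to return, so downstream functions of it are gauge-independent in the sense described above.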

The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

hmusgrave commented 4 years ago

@ianwilliamson do you have a use case for eig derivatives? Would be useful to know a bit about it.

Just in case this is helpful, I've never actually had a use case for eig derivatives per se since I've always had Hermitian matrices available, but the last time I reached for an eigh derivative it was because I needed to find a representative set of inputs yielding a singular Jacobian (a determinant derivative would have worked fine, but that was a bit slower and more unstable iirc -- I stopped searching when eigh derivatives were good enough). The scipy.linalg package was more helpful to me than the numpy wrapper because of its ability to single out a range of eigenvalues.

Most natural uses of an eig derivative I think would follow a similar pattern of having a deterministic scheme for choosing a particular eigenvalue (smallest magnitude, largest real part, etc) that relates to the problem being studied, or perhaps as inputs to a symmetric function.

I know you asked this in the context of eigenvector normalization. FWIW, I've always had to normalize them myself in whichever way suits the current problem, and I have never needed their derivatives except to compute higher-order derivatives of eigenvalues. Sorry I can't be more help there.

The problem with degeneracies is harder. In one of my packages, I purposefully add small noise to avoid symmetries that result in degeneracies, but that's obviously a workaround. Here's a paper that could provide some indication on how this could be handled ("generalized gradient"), but I don't really understand it well: https://oaktrust.library.tamu.edu/bitstream/handle/1969.1/184232/document-1.pdf?sequence=1

Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].
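Here is a quick way to see the behaviour described above. The value returned at the kink comes from JAX's derivative rule for abs, so treat the exact number as an implementation detail.

```python
import jax

# abs has a kink at 0, where its subdifferential is the interval [-1, 1].
# jax.grad returns a single number there rather than the whole set.
value_at_kink = jax.grad(abs)(0.0)
print(value_at_kink)
```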

momchilmm commented 4 years ago

Jax doesn't support subgradients at all does it? E.g., grad(abs)(0.)==1 even though the subdifferential there is the entire closed interval [-1, 1].

Ohh I see what this is about. Yeah I wouldn't expect this to be something that will be supported in jax. By the way, #3112 and #3114 might also be of interest to you.

LuukCoopmans commented 4 years ago

Hello, just wondering what the status is of the implementation of the np.linalg.eig function in JAX? I work as a quantum physicist and really like the JAX library. I successfully used it for an optimization problem involving the eigh function in a previous project. For a new project, however, I am dealing with non-Hermitian matrices, so I require the eig function.

j-towns commented 4 years ago

@LuukCoopmans np.linalg.eig is implemented but its derivatives are not. Do you need to be able to differentiate eig?

LuukCoopmans commented 4 years ago

@j-towns yes, I need to be able to differentiate it.

j-towns commented 4 years ago

Cool, as you can see in the comments above, the derivative for the eigen-vectors is quite awkward to get right because they’re only defined up to ‘gauge’ (that is up to multiplication by a complex scalar with absolute value 1).

@LuukCoopmans sorry to keep quizzing you, but does your objective function depend on the whole output of eig or just on the eigenvalues? The latter might be easier to support.

In the short term you might be interested in using JAX’s custom_jvp and custom_vjp for implementing your own workarounds where we haven’t managed to implement derivatives, like in this case.

LuukCoopmans commented 4 years ago

@j-towns actually I find this an interesting problem: in physics the quantum wavefunction (an eigenvector) is always defined up to a 'gauge', in the same way you describe. However, this gauge is usually not important for calculating the quantities of interest (expectation values), because it cancels out. Say O is a matrix and v is an eigenvector of some other matrix O'; then we are interested in quantities like v.T.conj() @ O @ v. Also, in my case I eventually take the absolute value squared of the eigenvector, so the phase is not important. I can, however, see that this might cause a problem for the derivative. If I am correct, the eigenvector and its derivative could come back in different gauges?
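To make the gauge argument concrete, here is a small sketch; the function names are hypothetical. An expectation value and the normalized squared amplitudes computed from a vector do not change when that vector is multiplied by a phase e^{iθ}. Once the normalization is written explicitly, they also survive multiplication by any nonzero complex scalar.

```python
import jax.numpy as jnp


def expectation(v, O):
    """<v|O|v> / <v|v>: unchanged when v is multiplied by any nonzero complex c."""
    # jnp.vdot conjugates its first argument, so this is v^H O v / v^H v.
    return jnp.vdot(v, O @ v) / jnp.vdot(v, v)


def probabilities(v):
    """|v_i|^2 / sum_j |v_j|^2: likewise independent of the scale and phase of v."""
    p = jnp.abs(v) ** 2
    return p / p.sum()


O = jnp.array([[1.0, 2.0 - 1.0j], [2.0 + 1.0j, -0.5]], dtype=jnp.complex64)
v = jnp.array([0.6 + 0.2j, -0.3 + 0.7j], dtype=jnp.complex64)
phase = jnp.exp(1j * 1.3)  # a pure gauge change e^{i*theta}
print(expectation(v, O), expectation(phase * v, O))
```

For a unit-norm eigenvector, as eig returns, the denominator is 1 and v^H O v alone is already phase-invariant.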

hmusgrave commented 4 years ago

However, this gauge is usually not important for the calculation of interested quantities (expectation values) because it gets multiplied out

This is my experience as well. I've only ever needed eigenvector derivatives in scenarios where the gauge didn't matter to the final calculation. I usually did need a particular magnitude, e.g. normalizing to |v|=1; jax does however easily support differentiating that normalization step, so I'm not sure that really matters for an eig() derivative.

Also in my case I eventually take the absolute value squared of the eigenvector so the phase is not important.

This quantity isn't uniquely defined either, and it's similar to the gauge problem. Eigenvectors are only unique up to a non-zero constant multiple from the relevant field.

shoyer commented 4 years ago

As I noted above in https://github.com/google/jax/issues/2748#issuecomment-627444706, I think np.linalg.eig is rarely the right level at which to calculate derivatives. We have conventions for how to pick the gauge for calculations, but those aren't necessarily consistent with derivatives. I think the problem of calculating reverse mode derivatives of eig may be fundamentally undefined from a mathematical perspective -- there does not necessarily exist a single choice of gauge for which the eig function is entirely continuous.

Instead, we need higher-level autodiff primitives, corresponding to well-defined functions that are invariant to gauge. For example, we can calculate derivatives for any matrix-valued function of a Hermitian matrix (see https://github.com/FluxML/Zygote.jl/pull/355). We could add helper functions for calculating these sorts of things, ideally with support for calculating the underlying functions in flexible ways (e.g., using eig internally).
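Here is a minimal sketch of a gauge-invariant matrix function built from an eigendecomposition. The name expm_via_eigh is hypothetical, and the code is not taken from the linked Zygote.jl PR. It computes the matrix exponential of a real symmetric matrix from eigh and compares its forward-mode derivative with jax.scipy.linalg.expm. It assumes distinct eigenvalues, because eigh's eigenvector derivatives are not defined otherwise.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def expm_via_eigh(a):
    """exp(A) for a real symmetric A, assembled as U diag(exp(w)) U^T from eigh.

    Flipping the sign of any column of U leaves the result unchanged, so this
    is a gauge-invariant function of A.
    """
    w, u = jnp.linalg.eigh(a)
    return (u * jnp.exp(w)) @ u.T


a = jnp.array([[2.0, 1.0], [1.0, 3.0]])     # symmetric, distinct eigenvalues
e = jnp.array([[0.1, -0.3], [-0.3, 0.2]])   # symmetric perturbation direction

# Forward-mode derivatives of both implementations along the direction e.
val_eigh, tan_eigh = jax.jvp(expm_via_eigh, (a,), (e,))
val_ref, tan_ref = jax.jvp(expm, (a,), (e,))
print(tan_eigh, tan_ref)
```

The eigenvector derivatives enter the intermediate steps, but the end result depends only on the well-defined function exp(A).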

hmusgrave commented 4 years ago

Instead, we need higher level auto-diff primitives, corresponding to well defined functions that are invariant of gauge.

That makes sense. Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

shoyer commented 4 years ago

Do you think it's still reasonable to support eigenvalue derivatives except on the measure-zero sets where they don't exist (either raising an error or providing a default value in such cases, sort of like how abs is handled)?

Yes, absolutely!

This is basically what we do currently for eigh. If there are degeneracies, then the derivative with respect to the eigenvectors will be all NaN.

j-towns commented 4 years ago

Is there a straightforward way for us to provide eigenvalue derivatives without providing eigenvector derivatives (since this gauge issue only affects evectors afaict)? Do you think we ought to have a primitive which only returns eigenvalues?

nikikilbertus commented 3 years ago

Is there a straightforward way for us to provide eigenvalue derivatives without providing eigenvector derivatives (since this gauge issue only affects evectors afaict)? Do you think we ought to have a primitive which only returns eigenvalues?

:+1: That would definitely solve my problem!

I'm working on a project in which we would like to compute gradients of a function that depends on the eigenvalues of non-Hermitian matrices (but not on the eigenvectors). From what I understand, the difficulty lies in computing gradients for the eigenvectors from eig, due to the ambiguity in their phase. Would it be possible to implement gradients for eigvals first (which internally calls eig without computing eigenvectors and returns only the eigenvalues)?

I think this would already cover a large fraction of applications in theoretical/mathematical physics.

j-towns commented 3 years ago

Hey @nikikilbertus, this has become straightforward since I wrote that comment because the JAX eig primitive now has kwargs to turn off the computation of eigenvectors. https://github.com/google/jax/pull/4941 should do the trick for you, just make sure your objective function uses jax.numpy.linalg.eigvals, rather than jax.numpy.linalg.eig.
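Here is a small sketch of the eigenvalues-only gradient in use. It assumes a JAX version that includes the change from the PR linked above, and the function name is hypothetical. Because sum_i λ_i² equals trace(A²), the gradient can be checked against the closed form 2Aᵀ. That check does not depend on the order in which eig returns the eigenvalues.

```python
import jax
import jax.numpy as jnp


def sum_of_squared_eigvals(a):
    # eigvals returns complex values; sum_i lambda_i**2 = trace(A @ A) is real
    # for real A, and .real gives the real scalar output that jax.grad needs.
    return jnp.sum(jnp.linalg.eigvals(a) ** 2).real


# A real, non-symmetric matrix with one real eigenvalue and a complex pair.
a = jnp.array([[1.0, 2.0, 0.0],
               [-1.0, 0.5, 3.0],
               [0.0, 1.0, 2.0]])
g = jax.grad(sum_of_squared_eigvals)(a)
print(g)
```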

Note that unfortunately second (and higher) derivatives aren't supported, I hope that's good enough to get you somewhere. If you need second derivs that should be possible but might be a bit more tricky to implement.

nikikilbertus commented 3 years ago

Thanks so much @j-towns this is great! 👏

Randl commented 3 years ago

There is a good post on gauge problem: https://re-ra.xyz/Gauge-Problem-in-Automatic-Differentiation/ Also related discussion in tensorflow https://github.com/tensorflow/tensorflow/pull/33808

benj252 commented 3 years ago

Hi, I am trying to get the derivatives of the eigenvalue decomposition of a matrix. At the moment I am trying to very simply have a function as follows:

import jax
from jax import grad

def matrixEigs(a1):
    arr = jax.numpy.array([[a1, 1], [1, 2]])
    eigs = jax.numpy.linalg.eigvals(arr)
    return eigs

eigGrad = grad(matrixEigs)

gradients = eigGrad(1.)

However, I get the error "NotImplementedError: Forward-mode differentiation rule for 'eig' not implemented".

j-towns commented 3 years ago

Please make sure that you're using a recent version of JAX. When I run

import jax
from jax import grad

def matrixEigs(a1):
    arr = jax.numpy.array([[a1, 1], [1, 2]])
    eigs = jax.numpy.linalg.eigvals(arr)
    return eigs

eigGrad = grad(matrixEigs)

gradients = eigGrad(1.)

I get

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

benj252 commented 3 years ago

Hi,

Thank you for the help before! I have re-installed JAX and I am now struggling to get the gradients of the imaginary parts of the eigenvalues. The following code:

import jax
from jax import grad

def matrixEigs(a1):
    arr = jax.numpy.array([[a1, -10.], [1., 2.]])
    eigs = jax.numpy.linalg.eigvals(arr)
    return eigs[0]

eigGrad = grad(matrixEigs)
gradientsReal = eigGrad(1.)

Returns the error 'TypeError: grad requires real-valued outputs (output dtype that is a sub-dtype of np.floating), but got complex64. For holomorphic differentiation, pass holomorphic=True. For differentiation of non-holomorphic functions involving complex outputs, or function with integer outputs, use jax.vjp directly'.

When I change the line where the gradient of matrixEigs is to:

eigGrad = grad(matrixEigs, holomorphic=True)

This returns the error: 'TypeError: grad with holomorphic=True requires inputs with complex dtype, but got float32.'

I can get the gradients of the real part fine by running the following:

def matrixEigsReal(a1):
    arr = jax.numpy.array([[a1, -10.], [1., 2.]])
    eigs = jax.numpy.linalg.eigvals(arr)
    return eigs[0].real

eigGradReal = grad(matrixEigsReal)
gradientsReal = eigGradReal(1.)

Please may you advise me on how to get the gradients of the imaginary part?

jakevdp commented 3 years ago

If you use holomorphic=True, you could convert your input to complex without changing its value:

eigGrad = grad(matrixEigs, holomorphic=True)
gradientsReal = eigGrad(complex(1.))

does that give you what you want?

momchilmm commented 3 years ago

I would also be a bit apprehensive in case there are implications of using holomorphic=True that I do not understand (maybe there aren't). Just to be on the safe side, I would rather compute the gradients of the real and imaginary parts separately. You've already done half the job in your last function; you could just also return eigs[0].imag and get the two gradients.

mattjj commented 3 years ago

I haven't read this thread yet, but I just wanted to chime in about holomorphic=True: all it does is disable some error messages (and enable others). To understand why, I recommend checking out the autodiff cookbook.

momchilmm commented 3 years ago

Ah, great, I didn't know this had been expanded on in such detail. So it does seem to me that holomorphic=True may not be the correct approach in this particular case. The issue is getting the derivative of a complex-valued eigenvalue w.r.t. a real-valued input, i.e. f: R -> C. If we know the function is indeed holomorphic, then this would work; otherwise, as far as I understand, it will return the complex grad(f.real) evaluated at x + 0j.

That's why I was saying I'd probably rather do something like grad_f = grad(f.real) + 1j*grad(f.imag) to avoid unexpected results.
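Here is a minimal sketch of that suggestion; the helper name first_eig is hypothetical. grad is applied to the real and imaginary parts separately and the two results are recombined. As a cross-check, jax.jvp accepts a complex output for a real input, so forward mode gives the complex derivative directly.

```python
import jax
import jax.numpy as jnp


def first_eig(a1):
    # Same 2x2 matrix as in the question above; eigvals returns complex values.
    arr = jnp.array([[a1, -10.0], [1.0, 2.0]])
    return jnp.linalg.eigvals(arr)[0]


# Both parts are real-valued functions of a real input, so plain jax.grad applies.
d_real = jax.grad(lambda x: first_eig(x).real)(1.0)
d_imag = jax.grad(lambda x: first_eig(x).imag)(1.0)
d_complex = d_real + 1j * d_imag

# Forward-mode cross-check: the tangent of the complex output.
lam, d_jvp = jax.jvp(first_eig, (1.0,), (1.0,))
print(d_complex, d_jvp)
```

For this matrix the eigenvalues at a1 = 1 are (3 ± i√39)/2. Their derivative with respect to a1 is 1/2 ± i/(2√39), so the imaginary part of the derivative equals Im(λ)/39 for either eigenvalue.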

gboehl commented 2 years ago

I have a use case where the derivative of the eigenvectors would be really handy.

I am looking for a stationary distribution for a given transition matrix:

import jax.numpy as jnp

def stationary_distribution(T):
    """Find invariant distribution of a Markov chain by unit eigenvector.
    """

    v, w = jnp.linalg.eig(T)

    # using sorted args instead of np.isclose is necessary for jax-jitting;
    # for a stochastic matrix the eigenvalue 1 is the largest, so it sorts last
    args = jnp.argsort(v)
    unit_ev = w[:, args[-1]]

    return unit_ev.real / unit_ev.real.sum()

I can do this by iteration, but this is quite costly when calculating the jacobian:

# transition matrix
tmat = jnp.array(((.9,.1),(.8,.2))).T
# brute-force approach
D0 = jnp.linalg.matrix_power(tmat, 99) @ jnp.ones(2)*.5
# using eigenvector
D1 = stationary_distribution(tmat)
# ensure this is the same
assert jnp.allclose(D0,D1) # passes

I do, however, need the Jacobian later to find a fixed point of a larger function that involves this stationary distribution (of course, for far larger transition matrices). Since the stationary distribution is unique, the solution should be unique.

Any chance for this to get implemented, or any suggestions for a good workaround?

Thanks for all this!
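Here is one possible workaround, sketched under stated assumptions rather than taken from the thread. For a column-stochastic T with a unique stationary distribution, π can be obtained from a linear solve, which JAX already differentiates, instead of from eig. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def stationary_distribution_solve(T):
    """Stationary distribution pi of a column-stochastic T (T @ pi = pi, sum(pi) = 1).

    Hypothetical alternative to the eig-based version: the columns of T sum
    to 1, so the rows of (T - I) are linearly dependent and one equation is
    redundant. It is replaced by the normalization sum(pi) = 1. Assumes the
    stationary distribution is unique, which makes the new system nonsingular.
    """
    n = T.shape[0]
    A = T - jnp.eye(n, dtype=T.dtype)
    A = A.at[-1, :].set(1.0)                      # last equation: sum(pi) = 1
    b = jnp.zeros(n, dtype=T.dtype).at[-1].set(1.0)
    return jnp.linalg.solve(A, b)


tmat = jnp.array(((.9, .1), (.8, .2))).T        # the example transition matrix
pi = stationary_distribution_solve(tmat)
jac = jax.jacfwd(stationary_distribution_solve)(tmat)   # shape (2, 2, 2)
print(pi)
```

This formulation never reads the last row of T, which is implied by the column sums. The Jacobian entries with respect to T[-1, :] are therefore zero, and the derivatives are meaningful along perturbations that keep T column-stochastic.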

hawkinsp commented 2 years ago

I'm not an expert on the math (see the discussion above), but the paper cited by @j-towns above (https://arxiv.org/pdf/1701.00392.pdf) has a derivation for the forward gradient of the eigenvectors (equation 4.63) under some assumptions. Try it out via a custom gradient?

gboehl commented 2 years ago

Thanks for the response!

Just in case anyone else runs into the same problem (looking for an AD'able way to find a stationary distribution): I used the following implementation I found a while ago:

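import jax.numpy as jnp
import jax.lax.linalg as lax_linalg
from jax import custom_jvp
from functools import partial

from jax import lax
from jax.numpy.linalg import solve

@custom_jvp
def eig(a):
    # eigenvalues w and right eigenvectors vr (left eigenvectors vl are unused)
    w, vl, vr = lax_linalg.eig(a)
    return w, vr

@eig.defjvp
def eig_jvp_rule(primals, tangents):
    a, = primals
    da, = tangents

    w, v = eig(a)

    eye = jnp.eye(a.shape[-1], dtype=a.dtype)
    # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs:
    # Fmat[..., i, j] = 1 / (w[j] - w[i]) off the diagonal, 0 on the diagonal
    Fmat = (jnp.reciprocal(eye + w[..., jnp.newaxis, :] - w[..., jnp.newaxis])
            - eye)
    dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)
    # V^{-1} dA V, using a linear solve rather than an explicit inverse
    vinv_da_v = dot(solve(v, da), v)
    du = dot(v, jnp.multiply(Fmat, vinv_da_v))
    # subtract each column's component along its own unit-norm eigenvector so
    # that v_j^H dv_j = 0, keeping the norm fixed to first order and pinning
    # the phase of the eigenvector derivative
    corrections = (jnp.conj(v) * du).sum(-2, keepdims=True)
    dv = du - v * corrections
    # eigenvalue tangents are the diagonal of V^{-1} dA V
    dw = jnp.diagonal(vinv_da_v, axis1=-2, axis2=-1)
    return (w, v), (dw, dv)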

This gave me correct results, but (for matrices larger than 10x10) is way more costly than just having a lax.while_loop on dot-multiplying the transition matrices.

Randl commented 1 year ago

The implementation above appears to be correct and matches the TF implementation. Is there any reason it can't be merged?

crmaunder commented 1 year ago

I am trying to take the Hessian of a linear combination of the magnitudes of the roots of a characteristic equation using jax autodiff.

$\nabla_x^2 \left( \mathrm{vec}^\top \lvert \mathrm{eigs}(x) \rvert \right)$

import jax.numpy as jnp
import jax
def jax_pole_zero_utility(coefs,a0):
    pz = jnp.roots(jnp.concatenate((jnp.array([a0]),coefs)))
    return jnp.abs(pz)

jax_pole_zero_hessian = jax.jacrev(jax.jacfwd(lambda params,a0,lagrange: jnp.vdot(lagrange,jax_pole_zero_utility(params,a0)),argnums=(0)),argnums=(0))
x = jnp.array([1,2,2],'float')
l = jnp.array([3,-1,4],'float')
jax_pole_zero_hessian(x,-1,l)

This yields the following error:

*** NotImplementedError: The derivatives of eigenvectors are not implemented, only eigenvalues. See https://github.com/google/jax/issues/2748 for discussion.

Note that taking a single gradient works fine:

jax_pole_zero_grad = jax.grad(lambda params,a0,lagrange: jnp.vdot(lagrange,jax_pole_zero_utility(params,a0)),argnums=(0))
jax_pole_zero_grad(x,-1,l)

All combinations of jax.jacfwd and jax.jacrev produce the same error. I don't think I am requesting the eigenvectors, so I am not sure why I am receiving this error, but I'd appreciate any pointers for a workaround.

j-towns commented 1 year ago

As mentioned in https://github.com/google/jax/issues/2748#issuecomment-729623959, second derivatives of eigenvalues aren’t supported yet. That’s because the computation of the first derivative of eigenvals depends on the eigenvectors, and therefore automatically differentiating it (to get the second derivative) fails, because we don’t have first derivatives of eigenvectors.

I think it should be possible to get this working, but I haven’t thought about it too hard yet.

j-towns commented 1 year ago

Specifically, it’s the dependence on v on this line that causes the problem.

SichengHe commented 1 year ago

We can use the adjoint method to compute (weighted) eigenvector derivatives accurately (see Eigenvalue problem derivatives computation for a complex matrix using the adjoint method), and we can approximate them (see Derivatives for Eigenvalues and Eigenvectors via Analytic Reverse Algorithmic Differentiation). Entry-wise eigenvector derivatives (d v[i] / d A) are still expensive to compute.


Also, we extended the dot-product trick developed by Giles, which derives the reverse-mode derivative from the forward-mode one, from real functions to complex analytic functions (the first paper).

rcarson3 commented 1 year ago

I will just throw out another use case where it would be nice to see this supported. I work in the computational solid mechanics field, and I needed eigenvector derivatives to calculate the Hessian defined in this paper: http://dx.doi.org/10.1016/j.cma.2016.11.026 . I was able to get things working thanks to @gboehl's workaround, and the results seem to align with my hand-rolled implementations of the Hessian, at least for the test cases I've checked against. However, it would be nice to have a native solution in JAX as well.