ThummeTo / DifferentiableEigen.jl

The current implementation of `LinearAlgebra.eigen` does not support sensitivities. DifferentiableEigen.jl offers an `eigen` function that is differentiable by any AD framework that supports ChainRulesCore.jl or ForwardDiff.jl.
MIT License

Eigen with complex output #7

Open oameye opened 5 months ago

oameye commented 5 months ago

The implementation below, taken from here, produces complex-valued eigenvalues and eigenvectors.

using LinearAlgebra
using ForwardDiff

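# Rebuild a Dual with value `val`, reusing the tag and partials of `partial`.
function make_eigen_dual(val::Real, partial)
    ForwardDiff.Dual{ForwardDiff.tagtype(partial)}(val, partial.partials)
end

# Complex case: do the same separately for the real and imaginary parts.
function make_eigen_dual(val::Complex, partial::Complex)
    Complex(ForwardDiff.Dual{ForwardDiff.tagtype(real(partial))}(real(val), real(partial).partials),
        ForwardDiff.Dual{ForwardDiff.tagtype(imag(partial))}(imag(val), imag(partial).partials))
end

# Extend LinearAlgebra.eigen instead of defining a new `eigen` in Main, so the
# inner call on the Float64 matrix still reaches the standard LAPACK method.
function LinearAlgebra.eigen(A::StridedMatrix{<:ForwardDiff.Dual})
    A_values = map(ForwardDiff.value, A)       # primal (Float64) matrix
    A_values_eig = eigen(A_values)
    # V⁻¹ A V: the value part is ≈ diag(λ), the partials are V⁻¹ dA V
    UinvAU = A_values_eig.vectors \ A * A_values_eig.vectors
    vals_diff = diag(UinvAU)                   # eigenvalues carrying their derivatives
    F = similar(A_values, eltype(A_values_eig.values))
    for i ∈ axes(A_values, 1), j ∈ axes(A_values, 2)
        if i == j
            F[i, j] = 0
        else
            F[i, j] = inv(A_values_eig.values[j] - A_values_eig.values[i])
        end
    end
    # The value part of this product is ≈ 0; its partials are the eigenvector derivatives.
    vectors_diff = A_values_eig.vectors * (F .* UinvAU)
    for i ∈ eachindex(vectors_diff)
        # Put the actual eigenvectors back in as values, keeping the partials.
        vectors_diff[i] = make_eigen_dual(A_values_eig.vectors[i], vectors_diff[i])
    end
    return Eigen(vals_diff, vectors_diff)
end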
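For reference, the method above is first-order eigenvalue perturbation theory, with the Dual partials playing the role of the perturbation dA (so `UinvAU` carries V⁻¹ dA V in its partials). Assuming A is diagonalizable with distinct eigenvalues (otherwise `F` divides by zero), the formulas follow from differentiating AV = VΛ:

```latex
\begin{aligned}
A V &= V \Lambda, \qquad \Lambda = \operatorname{diag}(\lambda_1,\dots,\lambda_n) \\
\mathrm{d}A\,V + A\,\mathrm{d}V &= \mathrm{d}V\,\Lambda + V\,\mathrm{d}\Lambda \\
V^{-1}\mathrm{d}A\,V &= \mathrm{d}\Lambda + C\Lambda - \Lambda C, \qquad C = V^{-1}\mathrm{d}V \\
\mathrm{d}\lambda_i &= \left(V^{-1}\mathrm{d}A\,V\right)_{ii} \\
C_{ij} &= \frac{\left(V^{-1}\mathrm{d}A\,V\right)_{ij}}{\lambda_j - \lambda_i} = F_{ij}\left(V^{-1}\mathrm{d}A\,V\right)_{ij}, \qquad i \neq j \\
\mathrm{d}V &= V\left(F \circ V^{-1}\mathrm{d}A\,V\right), \qquad F_{ii} = 0
\end{aligned}
```

The diagonal of C is not determined by this relation. Setting F_ii = 0, as the code does, is one normalization choice; it does not in general keep the eigenvector columns at unit norm.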

The code can be tested with:

function f(A)
    A_eig = eigen(A)
    return A_eig.vectors * Diagonal(A_eig.values) / A_eig.vectors
end

ForwardDiff.derivative(t -> f([t+1 t; t t+1]), 1.0)
ForwardDiff.jacobian(x -> f(x), [1.0 2.0; 4.0 1.0])

testfun1(A) = sum(abs2, eigen(A).values)
testfun2(A) = sum(abs2, eigvals(A))
A = randn(40,40)
testfun1(A)
testfun2(A)
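To check the perturbation formulas without ForwardDiff, here is a minimal sketch that applies the same steps to a plain Float64 matrix along a chosen direction dA, then compares the eigenvalue derivatives against central finite differences. The helper `eigen_directional`, the matrix `A`, and the direction `dA` are illustrative assumptions, not part of DifferentiableEigen.jl.

```julia
using LinearAlgebra

# Hypothetical helper (not part of DifferentiableEigen.jl): first-order change of
# the eigenvalues/eigenvectors of A along direction dA, using the same formulas
# as the Dual-number `eigen` method, but with plain Float64/ComplexF64 arithmetic.
function eigen_directional(A::AbstractMatrix, dA::AbstractMatrix)
    E = eigen(A)
    λ, V = E.values, E.vectors
    M = V \ (dA * V)                 # M = V⁻¹ dA V
    dλ = diag(M)                     # eigenvalue derivatives
    n = length(λ)
    F = [i == j ? zero(eltype(λ)) : inv(λ[j] - λ[i]) for i in 1:n, j in 1:n]
    dV = V * (F .* M)                # eigenvector derivatives (diagonal of V⁻¹dV set to 0)
    return λ, dλ, V, dV
end

A  = [-4.0 -17.0; 2.0 2.0]          # eigenvalues -1 ± 5im
dA = [1.0 0.5; -0.3 0.2]            # arbitrary perturbation direction
λ, dλ, V, dV = eigen_directional(A, dA)

# Reference: central finite differences of the eigenvalues.
h = 1e-6
dλ_fd = (eigvals(A + h * dA) - eigvals(A - h * dA)) / (2h)

# The eigenvector derivative must satisfy the differentiated relation
# dA*V + A*dV = dV*Λ + V*dΛ.
residual = norm(dA * V + A * dV - dV * Diagonal(λ) - V * Diagonal(dλ))
```

The eigenvalue comparison relies on `eigen` and `eigvals` returning eigenvalues in the same sorted order for nearby matrices, which holds here because the two eigenvalues are well separated.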
ThummeTo commented 5 months ago

Thanks for the info. I would totally agree with switching to an implementation with complex output, but it should work with at least ForwardDiff, ReverseDiff, and Zygote. Could you please check that?

oameye commented 4 months ago

I can't get it to work with ReverseDiff yet.

However, Zygote has native eigen support:

using Zygote, LinearAlgebra

function test_eigen(x)
    A = [x[1] x[2]; x[3] 2.0]
    λ = eigen(A).values
    return sum(abs.(λ))
end

function test_eigvals(x)
    A = [x[1] x[2]; x[3] 2.0]
    λ = eigvals(A)
    return sum(abs.(λ))
end

x = [-4.0, -17.0, 2.0]
G1_zyg = Zygote.gradient(test_eigen, x)[1]
G2_zyg = Zygote.gradient(test_eigvals, x)[1]
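As an independent check of these gradients: at x = [-4.0, -17.0, 2.0] the matrix has trace -2 and determinant 26, so its eigenvalues are the conjugate pair -1 ± 5im. Then |λ₁| = |λ₂| = √det(A) and the objective reduces to 2√det(A). The sketch below uses only LinearAlgebra; `objective`, `analytic_grad`, and `fd_grad` are hypothetical helper names. It compares that closed form with finite differences.

```julia
using LinearAlgebra

# Same objective as test_eigvals above, written without Zygote:
# the sum of |λ| over the eigenvalues of [x₁ x₂; x₃ 2].
objective(x) = sum(abs, eigvals([x[1] x[2]; x[3] 2.0]))

# Closed form, valid only while the eigenvalues form a complex-conjugate pair
# (trace² < 4·det): the objective is 2√det(A) with det(A) = 2x₁ - x₂x₃,
# so its gradient is ∇det(A) / √det(A).
function analytic_grad(x)
    D = 2 * x[1] - x[2] * x[3]
    return [2.0, -x[3], -x[2]] ./ sqrt(D)
end

# Central finite differences as an independent reference.
function fd_grad(f, x; h = 1e-6)
    g = zeros(length(x))
    for k in eachindex(x)
        e = zeros(length(x))
        e[k] = h
        g[k] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = [-4.0, -17.0, 2.0]
g_exact = analytic_grad(x)          # [2, -2, 17] / √26
g_fd = fd_grad(objective, x)
```

With Zygote loaded, `G1_zyg` and `G2_zyg` from the snippet above can be compared against `g_exact` the same way, e.g. `isapprox(G1_zyg, g_exact)`.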