JuliaSmoothOptimizers / Krylov.jl

A Julia Basket of Hand-Picked Krylov Methods

Support for eigenvalue problems #673

Open albertomercurio opened 1 year ago

albertomercurio commented 1 year ago

Hi,

I think it would be great to add support for eigenvalue solvers. I noticed that the implicitly restarted method has already been superseded: the modern, faster method is the Krylov-Schur method, which is already used in the KrylovKit.jl and ArnoldiMethod.jl packages.

A good reference is this one.
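To illustrate what a single Krylov-Schur restart does, here is a minimal dense, CPU-only sketch. It is not the KrylovKit.jl or ArnoldiMethod.jl implementation: it builds an m-step Arnoldi factorization, computes the complex Schur form of the projected matrix with LinearAlgebra's `schur`, moves the k Ritz values of largest modulus to the top-left with `ordschur!`, and truncates. The helper names `arnoldi_factorization` and `krylov_schur_truncate` are hypothetical, and the sketch assumes no Arnoldi breakdown.

```julia
using LinearAlgebra, Random

# Hypothetical helper: m-step Arnoldi factorization A * V[:, 1:m] = V * H,
# with V of size n × (m+1) (orthonormal columns) and H of size (m+1) × m
function arnoldi_factorization(A::AbstractMatrix, b::AbstractVector, m::Int)
    n = length(b)
    V = zeros(eltype(A), n, m + 1)
    H = zeros(eltype(A), m + 1, m)
    V[:, 1] = b / norm(b)
    for j in 1:m
        w = A * V[:, j]
        for i in 1:j                      # modified Gram-Schmidt orthogonalization
            H[i, j] = dot(V[:, i], w)
            w -= H[i, j] * V[:, i]
        end
        H[j+1, j] = norm(w)
        V[:, j+1] = w / H[j+1, j]         # assumes no breakdown (H[j+1, j] != 0)
    end
    return V, H
end

# Hypothetical helper: one Krylov-Schur truncation keeping the k Ritz values of
# largest modulus. Returns Vk, Tk, bk, q with A * Vk ≈ Vk * Tk + q * transpose(bk).
function krylov_schur_truncate(V::AbstractMatrix, H::AbstractMatrix, k::Int)
    m = size(H, 2)
    F = schur(complex(H[1:m, 1:m]))       # complex Schur form: H[1:m, 1:m] = Z * T * Z'
    select = falses(m)
    select[sortperm(abs.(F.values), rev = true)[1:k]] .= true
    ordschur!(F, select)                  # move the selected Ritz values to the top-left of T
    Z = F.Z
    Vk = V[:, 1:m] * Z[:, 1:k]            # new orthonormal (Schur) basis
    Tk = F.T[1:k, 1:k]                    # leading k × k block of the reordered Schur form
    bk = H[m+1, m] * Z[m, 1:k]            # coupling row: β times (last row of Z)[1:k]
    q = V[:, m+1]                         # last Arnoldi vector, the next basis vector
    return Vk, Tk, bk, q
end

# Usage on a random test matrix
Random.seed!(1)
A = randn(100, 100)
V, H = arnoldi_factorization(A, randn(100), 20)
Vk, Tk, bk, q = krylov_schur_truncate(V, H, 5)
```

Re-expanding from `q` with further Arnoldi steps and repeating the truncation gives the restarted iteration. The `eigsolve2` function in the code below follows the same pattern in place, using a custom `hseqr!`/`permuteschur!` pair instead of `schur`/`ordschur!`.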

Starting from the KrylovKit code and using the `arnoldi` structure from ExponentialUtilities.jl, I wrote the following code, which is much faster than the implicitly restarted method. It also already supports GPU arrays.

I opened an issue (https://github.com/JuliaLang/julia/issues/47291) asking Julia to support the LAPACK `hseqr!` solver natively, but for the moment I wrote my own wrapper.
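For comparison, on the CPU the pipeline used in the code below (`LAPACK.gehrd!`, `LAPACK.orghr!`, then a Schur factorization of the Hessenberg matrix) can be reproduced with LinearAlgebra's public `hessenberg` and `schur` functions. This is only a hedged sketch: `schur` calls LAPACK's general-matrix driver (?gees), which repeats the Hessenberg reduction internally before calling ?hseqr, so it does redundant work. That overhead is the reason a direct `hseqr!` binding is attractive.

```julia
using LinearAlgebra, Random

Random.seed!(2)
A = randn(8, 8)

# Hessenberg reduction A = Q * H * Q' (LinearAlgebra's `hessenberg`, backed by LAPACK ?gehrd)
F = hessenberg(A)
H = Matrix(F.H)            # upper Hessenberg matrix
Q = Matrix(F.Q)            # orthogonal factor

# Real Schur form of the Hessenberg matrix, H = Z * T * Z', via the built-in `schur`
# (general-matrix driver), used here as a stand-in for a dedicated hseqr! binding
S = schur(H)

# Accumulated Schur vectors of A, so that A = U * T * U'
U = Q * S.Z
```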

using MKL
using LinearAlgebra
using SparseArrays
using LinearMaps
using ExponentialUtilities
using BenchmarkTools
using Krylov # The master branch, actually
using IncompleteLU
using Arpack
using CUDA
using CUDA.CUSPARSE
CUDA.allowscalar(false)

using ExponentialUtilities: getV, getH, arnoldi_step!

import LinearAlgebra:
    BlasFloat,
    BlasInt,
    LAPACKException,
    DimensionMismatch,
    SingularException,
    PosDefException,
    chkstride1,
    checksquare
import LinearAlgebra.BLAS: @blasfunc, BlasReal, BlasComplex
import LinearAlgebra.LAPACK: chklapackerror
@static if VERSION >= v"1.7"
    const liblapack = LinearAlgebra.BLAS.libblastrampoline
else
    const liblapack = LinearAlgebra.LAPACK.liblapack
end
import Base: require_one_based_indexing

# Reorder a Schur factorization (T, Q) so that the eigenvalues appear in the order given by `p`
# (successive swaps with LAPACK trexc!)
permuteschur!(T::AbstractMatrix{<:BlasFloat}, p::AbstractVector{Int}) =
    permuteschur!(T, one(T), p)
function permuteschur!(
    T::AbstractMatrix{S},
    Q::AbstractMatrix{S},
    order::AbstractVector{Int}
) where {S<:BlasComplex}
    n = checksquare(T)
    p = collect(order) # copy, since `p` is modified in the loop below
    @inbounds for i in 1:length(p)
        ifirst::BlasInt = p[i]
        ilast::BlasInt = i
        T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
        for k in (i+1):length(p)
            if p[k] < p[i]
                p[k] += 1
            end
        end
    end
    return T, Q
end

function permuteschur!(
    T::AbstractMatrix{S},
    Q::AbstractMatrix{S},
    order::AbstractVector{Int}
) where {S<:BlasReal}
    n = checksquare(T)
    p = collect(order) # copy, since `p` is modified in the loop below
    i = 1
    @inbounds while i <= length(p)
        ifirst::BlasInt = p[i]
        ilast::BlasInt = i
        if ifirst == n || iszero(T[ifirst+1, ifirst])
            T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
            @inbounds for k in (i+1):length(p)
                if p[k] < p[i]
                    p[k] += 1
                end
            end
            i += 1
        else
            p[i+1] == ifirst + 1 ||
                error("cannot split 2x2 blocks when permuting schur decomposition")
            T, Q = LAPACK.trexc!(ifirst, ilast, T, Q)
            @inbounds for k in (i+2):length(p)
                if p[k] < p[i]
                    p[k] += 2
                end
            end
            i += 2
        end
    end
    return T, Q
end

# Wrappers for LAPACK ?hseqr (Schur factorization of an upper Hessenberg matrix),
# which LinearAlgebra.LAPACK does not expose natively
for (hseqr, elty) in
    ((:dhseqr_, :Float64), (:shseqr_, :Float32))
    @eval begin
        function hseqr!(H::StridedMatrix{$elty}, Z::StridedMatrix{$elty} = one(H))
            require_one_based_indexing(H, Z)
            chkstride1(H, Z)
            n = checksquare(H)
            checksquare(Z) == n || throw(DimensionMismatch())
            job = 'S'
            compz = 'V'
            ilo = 1
            ihi = n
            ldh = stride(H, 2)
            ldz = stride(Z, 2)
            wr = similar(H, $elty, n)
            wi = similar(H, $elty, n)
            work = Vector{$elty}(undef, 1)
            lwork = BlasInt(-1)
            info = Ref{BlasInt}()
            for i in 1:2  # first call returns lwork as work[1]
                ccall(
                    (@blasfunc($hseqr), liblapack),
                    Cvoid,
                    (
                        Ref{UInt8},
                        Ref{UInt8},
                        Ref{BlasInt},
                        Ref{BlasInt},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ptr{$elty},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{BlasInt},
                        Clong,
                        Clong
                    ),
                    job,
                    compz,
                    n,
                    ilo,
                    ihi,
                    H,
                    ldh,
                    wr,
                    wi,
                    Z,
                    ldz,
                    work,
                    lwork,
                    info,
                    1,
                    1
                )
                chklapackerror(info[])
                if i == 1
                    lwork = BlasInt(real(work[1]))
                    resize!(work, lwork)
                end
            end
            return H, Z, complex.(wr, wi)
        end
    end
end

for (hseqr, elty, relty) in (
    (:zhseqr_,  :ComplexF64, :Float64),
    (:chseqr_,  :ComplexF32, :Float32)
)
    @eval begin
        function hseqr!(H::StridedMatrix{$elty}, Z::StridedMatrix{$elty} = one(H))
            require_one_based_indexing(H, Z)
            chkstride1(H, Z)
            n = checksquare(H)
            checksquare(Z) == n || throw(DimensionMismatch())
            job = 'S'
            compz = 'V'
            ilo = 1
            ihi = n
            ldh = stride(H, 2)
            ldz = stride(Z, 2)
            w = similar(H, $elty, n)
            work = Vector{$elty}(undef, 1)
            lwork = BlasInt(-1)
            info = Ref{BlasInt}()
            for i in 1:2  # first call returns lwork as work[1]
                ccall(
                    (@blasfunc($hseqr), liblapack),
                    Cvoid,
                    (
                        Ref{UInt8},
                        Ref{UInt8},
                        Ref{BlasInt},
                        Ref{BlasInt},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{$elty},
                        Ref{BlasInt},
                        Ptr{BlasInt},
                        Clong,
                        Clong
                    ),
                    job,
                    compz,
                    n,
                    ilo,
                    ihi,
                    H,
                    ldh,
                    w,
                    Z,
                    ldz,
                    work,
                    lwork,
                    info,
                    1,
                    1
                )
                chklapackerror(info[])
                if i == 1
                    lwork = BlasInt(real(work[1]))
                    resize!(work, lwork)
                end
            end
            return H, Z, w
        end
    end
end
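# Krylov-Schur eigensolver: approximates the k eigenvalues of largest modulus of A
# using a Krylov subspace of dimension m (`isherm` is currently unused)
function eigsolve2(A, b, k::Int = min(4, size(A, 1)), m::Int = min(10, size(A, 1)); tol::Real = 1e-8, maxiter::Int = 200,
    isherm = ishermitian(A))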

    M = typeof(reshape(b, length(b), 1)) # matrix type matching b (e.g. CuMatrix), used to move small CPU matrices to the device

    Ks = ExponentialUtilities.arnoldi(A, b, m = m)
    V = getV(Ks)
    H = getH(Ks)
    f = ones(eltype(V), m)  # residual estimates of the m Ritz values

    Vₘ = view(V, :, 1:m)
    Hₘ = view(H, 1:m, :)
    qₘ = view(V, :, m+1)
    βeₘ = view(H, m+1, :)
    β = H[m+1, m]
    Uₘ = one(Hₘ)

    numops = m
    iter = 1
    while iter < maxiter && sum(abs.(f) .< tol) < k
        m = Ks.m

        # println( A * Vₘ ≈ Vₘ * M(Hₘ) + qₘ * M(transpose(βeₘ)) )     # SHOULD BE TRUE

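        # Schur form of the projected matrix, Hₘ = Uₘ Tₘ Uₘ'
        # (Hessenberg reduction with gehrd!/orghr!, then hseqr!; Tₘ aliases Hₘ)
        Tₘ, tau = LAPACK.gehrd!(1, m, Hₘ)
        copyto!(Uₘ, Tₘ)
        LAPACK.orghr!(1, m, Uₘ, tau)
        Tₘ, Uₘ, values = hseqr!(Hₘ, Uₘ)
        # Sort the Ritz values by decreasing modulus and reorder the Schur form accordingly
        p = sortperm(values, by = abs, rev = true)
        Tₘ, Uₘ = permuteschur!(Tₘ, Uₘ, p)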

        mul!(f, view(Uₘ, m, 1:m), β)  # residual estimates: β times the last row of Uₘ

        # println( A * Vₘ * Uₘ ≈ Vₘ * Uₘ * M(Tₘ) + qₘ * M(transpose(βeₘ)) * Uₘ )     # SHOULD BE TRUE

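        # Truncate to the k wanted Schur vectors and restore the Krylov-Schur relation
        # A Vₖ = Vₖ Tₖ + q bᵀ, with q the last Arnoldi vector and bᵀ = (βeₘᵀ Uₘ)[1:k]
        # (Tₖ already sits in H[1:k, 1:k], since hseqr! overwrote Hₘ in place)
        copyto!(view(V, :, 1:k), view(Vₘ * M(Uₘ), :, 1:k))
        copyto!(view(V, :, k+1), qₘ)
        copyto!(view(H, k+1, 1:k), view(transpose(βeₘ) * Uₘ, 1:k))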

        # println( A * view(V, :, 1:k) ≈ view(V, :, 1:k) * M(view(H, 1:k, 1:k)) + qₘ * M(transpose(view(transpose(βeₘ) * Uₘ, 1:k))) )     # SHOULD BE TRUE

        # Expand the factorization back to dimension m with Arnoldi steps j = k+1, ..., m
        for j in k+1:m
            β = arnoldi_step!(j, m, A, V, H, size(V, 1), 0)
        end

        numops += m - k  # one application of A per Arnoldi step j = k+1, ..., m
        iter+=1
    end

    # Final Schur decomposition of the projected matrix, sorted by decreasing modulus
    Tₘ, tau = LAPACK.gehrd!(1, m, Hₘ)
    copyto!(Uₘ, Tₘ)
    LAPACK.orghr!(1, m, Uₘ, tau)
    Tₘ, Uₘ, values = hseqr!(Hₘ, Uₘ)
    p = sortperm(values, by = abs, rev = true)
    Tₘ, Uₘ = permuteschur!(Tₘ, Uₘ, p)
    mul!(f, view(Uₘ, m, 1:m), β)

    vals = values[p[1:k]]  # k Ritz values of largest modulus, in the order of the reordered Schur form
    # Eigenvectors of Tₘ (trevc!), mapped back through Uₘ and the Krylov basis Vₘ
    select = Vector{BlasInt}(undef, 0)
    VR = LAPACK.trevc!('R', 'A', select, Tₘ)
    @inbounds for i in 1:size(VR, 2)
        normalize!(view(VR, :, i))
    end
    vecs = (Vₘ * M(Uₘ * VR))[:, 1:k]

    return vals, vecs, (iter, numops)
end
N = 1000
elty = ComplexF64

vals0 = Array(1:N)
vecs0 = rand(elty, N, N)

A = vecs0 \ diagm(vals0) * vecs0
A = sparse(A)
# droptol!(A, 1000)
v0 = rand(elty, N)

μ = 0 # shift for shift-and-invert: the solver targets eigenvalues of (A - μI)⁻¹
A_s = A - μ * I

A_s2 = CuSparseMatrixCSR(A_s)
v02 = CuArray(v0)
P2 = ILU0gpu(ilu02(A_s2, 'O')) # ILU(0) preconditioner on the GPU; `ILU0gpu` is not defined in this snippet
solver = GmresSolver(size(A_s2)..., 20, typeof(v02))
# Operator x ↦ (A - μI)⁻¹ x, applied through a preconditioned GMRES solve on the GPU
Lmap = LinearMap{eltype(v02)}((y, x) -> copyto!(y, gmres!(solver, A_s2, x, M = P2, ldiv = true, restart=true).x), size(A_s2, 1))
vals, vecs, info = eigsolve2(Lmap, v02, 6, 20)
show(info)
vals = (1 .+ μ * vals) ./ vals # back-transform: λ = μ + 1/θ for each eigenvalue θ of (A - μI)⁻¹
abs.(vals)
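The last lines above undo the shift-and-invert transformation: if θ is an eigenvalue of (A - μI)⁻¹, the corresponding eigenvalue of A is λ = μ + 1/θ = (1 + μθ)/θ, and the eigenvalues of A closest to μ become the largest in modulus. A small dense, CPU-only check of this identity, with an arbitrary shift μ = 3.2 and a diagonal test matrix chosen purely for illustration:

```julia
using LinearAlgebra

# Small test problem with known eigenvalues 1, 2, ..., 6
A = Diagonal(Float64.(1:6))
μ = 3.2                                   # illustrative target shift

# Eigenvalues θ of the shift-and-invert operator (A - μI)⁻¹
θ = eigvals(inv(Matrix(A) - μ * I))

# Back-transformation used in the script: λ = μ + 1/θ = (1 + μθ)/θ
λ = (1 .+ μ .* θ) ./ θ

# The largest |θ| corresponds to the eigenvalue of A closest to μ
closest = λ[argmax(abs.(θ))]
```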
amontoison commented 1 year ago

Hi @albertomercurio,

I'm not sure I understand your issue: do you want an eigenvalue solver in Krylov.jl?

This package is dedicated to Krylov processes and Krylov methods for linear problems. It's not a catch-all package like IterativeSolvers.jl and KrylovKit.jl.

I'm open to adding singular value / eigenvalue solvers, but I'm not an expert in that field.

dpo commented 1 year ago

Those could easily go into an extension package though.

albertomercurio commented 1 year ago

> I'm not sure I understand your issue: do you want an eigenvalue solver in Krylov.jl?
>
> This package is dedicated to Krylov processes and Krylov methods for linear problems. It's not a catch-all package like IterativeSolvers.jl and KrylovKit.jl.

Ok, clear.

Anyway, eigenvalue solvers are built on Krylov subspaces, and I thought of this package because it also supports these algorithms on GPUs.