PetrKryslUCSD / Sparspak.jl

Direct solution of large sparse systems of linear algebraic equations in pure Julia
MIT License

Differentiable sparse solver #4

Closed j-fu closed 1 year ago

j-fu commented 2 years ago

Hi, this package is an incredible feat, which gets us very close to having a differentiable sparse matrix solver.

I tried this out on a small test problem; there is still a small distance to go :)

Here is a small test problem:

using Tensors
using LinearAlgebra
using Sparspak
using SparseArrays

function tridiagonal(p,n)
    T=typeof(p)
    b=T[p^i for i=1:n]
    a=fill(T(-0.1),n-1)
    c=fill(T(-0.1),n-1)
    Tridiagonal(a,b,c)
end

function f(p)
    n=20
    M=Matrix(tridiagonal(p,n))
    f=ones(n)
    sum(M\f)
end

df(x)=Tensors.gradient(f,x)

function g(p)
    n=20
    M=sparse(tridiagonal(p,n))
    f=ones(n)
    pr = Sparspak.SpkProblem.Problem(n,n,nnz(M),zero(p))
    Sparspak.SpkProblem.insparse!(pr, M)
    Sparspak.SpkProblem.infullrhs!(pr,f)
    s = Sparspak.SparseSolver.SparseSolver(pr)
    Sparspak.SparseSolver.solve!(s)    
    sum(pr.x)
end

dg(x)=Tensors.gradient(g,x)

The result is:

julia> f(4.0)
0.33672179269725333

julia> df(4.0)
-0.11378327767038768

julia> g(4.0)
0.33672179269725333

julia> dg(4.0)
-10.000000059604645

julia> h=1.0e-5
1.0e-5

julia> (g(4.0+h)-g(4.0))/h
-0.1137828930852791

So this indeed runs (after replacing Problem(... by Problem{IT,FT}(... here; not sure if this warrants a PR...), but the result is incorrect.

PetrKryslUCSD commented 2 years ago

Actually, there may be an obstacle to a differentiable sparse linear solver under the hood: level-3 BLAS routines. I do not think we can differentiate through those.
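The obstacle can be seen with any element type that is not a machine float; a minimal sketch, with BigFloat standing in for a derivative-carrying dual-number type:

```julia
using LinearAlgebra

A = rand(2, 2)
BLAS.gemm('N', 'N', A, A)          # fine: Float64 is a BLAS element type

Abig = BigFloat.(A)
# BLAS.gemm('N', 'N', Abig, Abig)  # MethodError: BLAS only accepts
#                                  # Float32/Float64 and their complex forms
Abig * Abig                        # Julia's generic fallback still works
```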


j-fu commented 2 years ago

I see... but having the basic algorithmic infrastructure available may nevertheless help with getting there. I will think more about it when I (hopefully) find time.

j-fu commented 2 years ago

It might be possible to write a generic dgemm! that falls back to Julia's generic linear algebra methods and, in the case of Float64, Float32 and their complex equivalents, specializes to what you have written. Maybe the same for the other BLAS/LAPACK methods you wrote.
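The dispatch pattern described above might look like the following minimal sketch (the name mygemm! and the exact signatures are hypothetical, not Sparspak's actual API):

```julia
using LinearAlgebra

# Generic fallback: works for any element type supporting * and +,
# e.g. ForwardDiff.Dual, MultiFloats, BigFloat.
function mygemm!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix,
                 alpha, beta)
    mul!(C, A, B, alpha, beta)   # Julia's generic five-argument mul!
end

# Specialization for machine floats: call the optimized BLAS routine.
function mygemm!(C::StridedMatrix{T}, A::StridedMatrix{T},
                 B::StridedMatrix{T}, alpha::T, beta::T) where {T<:LinearAlgebra.BlasFloat}
    BLAS.gemm!('N', 'N', alpha, A, B, beta, C)
end
```

For Float64 and friends, Julia's dispatch picks the more specific BLAS method; every other element type lands on the generic fallback.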

PetrKryslUCSD commented 2 years ago

Correct. All those BLAS routines will have to be rewritten.

j-fu commented 2 years ago

I started to work on a PR for this. Will have to manage my time, so it may take a while. Two remarks though:

j-fu commented 2 years ago

I am close to having it - code runs e.g. with MultiFloats and ForwardDiff.Dual, the problem from the initial post works.

There still seem to be occasional pivoting and indexing problems (my BLAS replacement runs with bounds-checking) in the tests based on the sprand matrices. I will try to catch some corner cases before submitting.

One question though: what I did just adds new generic methods for dgemm! and the like, plus of course tests. So in principle we have two options: (a) keep going with a PR to Sparspak.jl, or (b) have another package (e.g. SparspackGenericLinalg.jl) which just imports dgemm! etc. from Sparspak.jl and extends them (didn't test this yet).

What is your preference? I have a slight preference for (a).

EDIT: The branch is on https://github.com/j-fu/Sparspak.jl/tree/generic-blas

j-fu commented 2 years ago

Regarding #10: I made another test regarding autodiff, and the result stays the same: ForwardDiff.jl and FiniteDiff.jl work well. ReverseDiff.jl and Zygote.jl lack support for mutation (setindex!), and Enzyme makes a rather raw impression.
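For illustration, the reason forward mode copes with mutating solver code while the reverse-mode tools do not: ForwardDiff stores its Dual numbers in ordinary arrays, so setindex! is unproblematic. A toy sketch (h is a made-up function, not from the test suite):

```julia
using ForwardDiff

# A toy function full of mutation, as solver internals typically are.
function h(p)
    x = zeros(typeof(p), 3)   # element type follows p: Float64 or Dual
    for i in 1:3
        x[i] = p^i            # setindex! — fatal for Zygote/ReverseDiff,
    end                       # harmless for ForwardDiff
    sum(x)
end

ForwardDiff.derivative(h, 2.0)   # d/dp (p + p^2 + p^3) = 1 + 2p + 3p^2 = 17 at p = 2
```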

j-fu commented 1 year ago

I think with the advent of 0.3.x we can close this.