Actually, there may be an obstacle to a differentiable sparse linear solver under the hood: Level 3 BLAS routines. I do not think we can differentiate through those.
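A minimal illustration of the obstacle: Julia's BLAS wrappers are only defined for the hardware floating-point element types, so matrices of dual numbers cannot pass through them.

```julia
using LinearAlgebra, ForwardDiff

A = rand(2, 2); B = rand(2, 2); C = zeros(2, 2)
LinearAlgebra.BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C)  # fine: Float64 is a BLAS type

D = ForwardDiff.Dual.(A, 1.0)  # matrix of dual numbers
# BLAS.gemm! has no method for Dual element types, so the analogous call
# below would throw a MethodError instead of propagating derivatives:
# LinearAlgebra.BLAS.gemm!('N', 'N', one(eltype(D)), D, D, zero(eltype(D)), similar(D))
```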
On Mon, Aug 29, 2022, 1:29 PM Jürgen Fuhrmann wrote:
Hi, this package is an incredible feat, which gets us very close to the availability of a differentiable sparse matrix solver.
I tried this out on a small test problem; there is still a small distance to go :)
Here is a small test problem:
```julia
using Tensors
using LinearAlgebra
using Sparspak
using SparseArrays

function tridiagonal(p,n)
    T=typeof(p)
    b=T[p^i for i=1:n]
    a=fill(T(-0.1),n-1)
    c=fill(T(-0.1),n-1)
    Tridiagonal(a,b,c)
end

function f(p)
    n=20
    M=Matrix(tridiagonal(p,n))
    f=ones(n)
    sum(M\f)
end

df(x)=Tensors.gradient(f,x)

function g(p)
    n=20
    M=sparse(tridiagonal(p,n))
    f=ones(n)
    pr = Sparspak.SpkProblem.Problem(n,n,nnz(M),zero(p))
    Sparspak.SpkProblem.insparse!(pr, M)
    Sparspak.SpkProblem.infullrhs!(pr,f)
    s = Sparspak.SparseSolver.SparseSolver(pr)
    Sparspak.SparseSolver.solve!(s)
    sum(pr.x)
end

dg(x)=Tensors.gradient(g,x)
```
The result is:
```julia
julia> f(4.0)
0.33672179269725333

julia> df(4.0)
-0.11378327767038768

julia> g(4.0)
0.33672179269725333

julia> dg(4.0)
-10.000000059604645

julia> h=1.0e-5
1.0e-5

julia> (g(4.0+h)-g(4.0))/h
-0.1137828930852791
```
So this indeed runs (after replacing `Problem(...` by `Problem{IT,FT}(...` here: https://github.com/PetrKryslUCSD/Sparspak.jl/blob/6e86df5c89a19d18cd9b056b09f24bfbc279e7ea/src/Problem/SpkProblem.jl#L165, not sure if this warrants a PR...), but the result is incorrect.
I see... but having the basic algorithm infrastructure available may nevertheless help with getting there. Will think more about it when I (hopefully) find time.
It might be possible to write a generic dgemm! that falls back to Julia's generic linear algebra methods and, in the case of Float64, Float32 and their complex equivalents, specializes to what you have written. Maybe the same for the other BLAS/LAPACK methods you wrote.
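A minimal sketch of such a generic fallback (the name `generic_gemm!` and its argument list are illustrative, not Sparspak's actual `dgemm!` interface): Julia's five-argument `mul!` already specializes to BLAS for `Float64`, `Float32` and their complex counterparts and falls back to generic, differentiable Julia code for other element types, so a single method covers both cases.

```julia
using LinearAlgebra

# Illustrative generic gemm replacement: C = alpha*A*B + beta*C.
# mul! dispatches to BLAS.gemm! for the four standard BLAS element types
# and to LinearAlgebra's generic matmul for everything else (Dual, MultiFloat, ...).
function generic_gemm!(alpha::T, A::AbstractMatrix{T}, B::AbstractMatrix{T},
                       beta::T, C::AbstractMatrix{T}) where {T}
    mul!(C, A, B, alpha, beta)
end
```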
Correct. All those BLAS routines will have to be rewritten.
I started to work on a PR for this. Will have to manage my time, so it may take a while. Two remarks though:
- I am close to having it: the code runs e.g. with MultiFloats and ForwardDiff.Dual, and the problem from the initial post works.
- There still seem to be occasional pivoting and indexing problems (my BLAS replacement runs with bounds checking) in the tests based on the sprand matrices. I will try to catch some corner cases before submitting.
One question though: what I did just adds new generic methods to dgemm! and the like, plus tests. So in principle we have two options:
a) keep going with a PR to Sparspak.jl
b) have another package (e.g. SparspackGenericLinalg.jl) which just imports dgemm! etc. from Sparspak.jl and extends them (didn't test this yet; a sketch of the pattern follows below)

What is your preference? I have a slight preference for a).
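A sketch of what option b) could look like; the submodule path `Sparspak.SpkSpdMMops` and the `dgemm!` argument list below are assumptions made for illustration only, not Sparspak's documented interface.

```julia
# Hypothetical standalone package extending Sparspak's BLAS wrappers with
# generic methods: import the function, then add methods for new element types.
module SparspackGenericLinalg

using LinearAlgebra
import Sparspak.SpkSpdMMops: dgemm!   # assumed location of the wrapper

# Generic method for element types without a BLAS gemm (Dual, MultiFloat, ...)
function dgemm!(transa::AbstractChar, transb::AbstractChar, alpha::T,
                A::AbstractMatrix{T}, B::AbstractMatrix{T},
                beta::T, C::AbstractMatrix{T}) where {T}
    Aop = transa in ('n', 'N') ? A : transpose(A)
    Bop = transb in ('n', 'N') ? B : transpose(B)
    mul!(C, Aop, Bop, alpha, beta)
end

end # module
```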
EDIT: The branch is on https://github.com/j-fu/Sparspak.jl/tree/generic-blas
Regarding #10: I made another test regarding autodiff, and the result stays the same: ForwardDiff.jl and FiniteDiff.jl work well, ReverseDiff.jl and Zygote.jl cannot handle the mutation (setindex!), and Enzyme still makes a rather raw impression.
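For reference, a quick check along these lines, assuming the `g` from the initial post is in scope:

```julia
using ForwardDiff, FiniteDiff

# Forward-mode AD through the generic sparse solve;
# expected to match df(4.0) ≈ -0.1138 from the initial post.
ForwardDiff.derivative(g, 4.0)

# Finite-difference reference value
FiniteDiff.finite_difference_derivative(g, 4.0)
```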
I think with the advent of 0.3.x we can close this.