SciML / LinearSolve.jl

LinearSolve.jl: High-Performance Unified Interface for Linear Solvers in Julia. Easily switch between factorization and Krylov methods and add preconditioners, all in one interface.
https://docs.sciml.ai/LinearSolve/stable/

Gradients and Hessian-vector products #152

Open omalled opened 2 years ago

omalled commented 2 years ago

I saw in the docs that "the current algorithms should support automatic differentiation," and played around with it a bit. My ultimate goal is to get a Hessian-vector product working (similar to this issue in AlgebraicMultigrid, which was never resolved despite some effort being made). However, I wasn't able to get a Hessian-vector product or even a gradient working in a relatively simple example:

using Test
import ForwardDiff
import LinearAlgebra
import LinearSolve
import SparseArrays
import Zygote

# Forward-over-reverse HVP: ForwardDiff differentiates the Zygote gradient along direction v
hessian_vector_product(f, x, v) = ForwardDiff.jacobian(s -> Zygote.gradient(f, x + s[1] * v)[1], [0.0])[:]

n = 4
A = randn(n, n)
hessian = A + A'  # Hessian of x -> dot(x, A * x)
f(x) = LinearAlgebra.dot(x, A * x)
x = randn(n)
v = randn(n)
hvp1 = hessian_vector_product(f, x, v)
hvp2 = hessian * v
@test hvp1 ≈ hvp2  # the hessian_vector_product plausibly works!

function g(x)
    # Tridiagonal system parameterized by the first n + 1 entries of x;
    # the remaining n entries form the right-hand side
    k = x[1:n + 1]
    B = SparseArrays.spdiagm(0 => k[1:end - 1] + k[2:end], -1 => -k[2:end - 1], 1 => -k[2:end - 1])
    prob = LinearSolve.LinearProblem(B, x[n + 2:end])
    sol = LinearSolve.solve(prob)
    return sum(sol.u)
end
x = randn(2 * n + 1)
v = randn(2 * n + 1)
Zygote.gradient(g, x)  # ERROR: Can't differentiate foreigncall expression
hessian_vector_product(g, x, v)  # LoadError: MethodError: no method matching SuiteSparse.UMFPACK.UmfpackLU...

Is there any chance to get these derivatives working with LinearSolve? It would be really great, especially the Hessian-vector products. Thanks for your help and your great work on this package! Please let me know if there's something I can do to help get this working 😄

ChrisRackauckas commented 2 years ago

Yeah... that's why it said "should" in the "Roadmap" section of the docs 😅. A lot of cases end up working out since AD just differentiates through the algorithm, and things like lu/qr/svd have ChainRules defined on them, so a lot of cases "accidentally" work. But what we need to do is lower to solve_ab(A, b, sensealg, alg) etc. and then define the chain rule on that, the same as the rule for \ in ChainRules.jl: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/arraymath.jl#L336-L359
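
For concreteness, a minimal sketch of what such a rule could look like in the square, nonsingular case. Here solve_ab(A, b) is a hypothetical lowered entry point (not LinearSolve's actual API; the real one would also carry sensealg and alg), and the pullback mirrors the ChainRules rule for \ linked above:

using LinearAlgebra
import ChainRulesCore: rrule, NoTangent, unthunk

# Hypothetical lowered entry point; stands in for solve_ab(A, b, sensealg, alg)
solve_ab(A, b) = A \ b

function rrule(::typeof(solve_ab), A, b)
    x = solve_ab(A, b)
    function solve_ab_pullback(x̄)
        λ = solve_ab(A', unthunk(x̄))  # adjoint solve: A'λ = x̄
        Ā = -λ * x'                   # ∂L/∂A = -λxᵀ
        b̄ = λ                         # ∂L/∂b = λ
        return NoTangent(), Ā, b̄
    end
    return x, solve_ab_pullback
end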

Then the pullback just needs a solve on the adjoint system, i.e. https://github.com/SciML/LinearSolve.jl/issues/92
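
Roughly, that adjoint solve written with LinearSolve itself might look like the following sketch. x̄ stands in for the incoming cotangent, and the matrix is materialized here for simplicity; ideally the transposed solve would reuse the factorization from the forward pass, which is what #92 is about:

import LinearSolve

A = randn(4, 4)
x̄ = randn(4)  # incoming cotangent from the pullback
adjoint_prob = LinearSolve.LinearProblem(Matrix(A'), x̄)  # solves A'λ = x̄
λ = LinearSolve.solve(adjoint_prob).u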