JuliaLinearAlgebra / AlgebraicMultigrid.jl

Algebraic Multigrid in Julia

Hessian vector products #67

Open omalled opened 4 years ago

omalled commented 4 years ago

Would it be possible to get code like the example below working? The first example, with the function f, is meant to show that this definition of hessian_vector_product can work: its result matches the analytic Hessian A + A' times v. The second example shows that it fails for g, which uses AlgebraicMultigrid. Being able to compute Hessian-vector products efficiently this way would be really useful.
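The definition of hessian_vector_product used below relies on a standard identity. A Hessian-vector product is the directional derivative of the gradient along v, so it can be computed forward-mode (ForwardDiff, over a single scalar s) on top of a reverse-mode gradient (Zygote). For the quadratic f in the first example, this also gives the closed form that the test compares against:

```latex
H(x)\,v \;=\; \left.\frac{d}{ds}\,\nabla f(x + s\,v)\right|_{s=0},
\qquad
f(x) = x^\top A x \;\Longrightarrow\; \nabla f(x) = (A + A^\top)\,x,
\qquad H = A + A^\top .
```

ForwardDiff.jacobian over the length-1 vector [s] returns an n×1 matrix. The trailing [:] flattens it into the vector H v.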

using Test
import AlgebraicMultigrid
import ForwardDiff
import LinearAlgebra
import SparseArrays
import Zygote

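# Forward-over-reverse: ForwardDiff differentiates Zygote's gradient of f at x + s*v with respect to s, at s = 0.
hessian_vector_product(f, x, v) = ForwardDiff.jacobian(s->Zygote.gradient(f, x + s[1] * v)[1], [0.0])[:]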

n = 4
A = randn(n, n)
hessian = A + A'
f(x) = LinearAlgebra.dot(x, A * x) 
x = randn(n)
v = randn(n)
hvp1 = hessian_vector_product(f, x, v)
hvp2 = hessian * v
@test hvp1 ≈ hvp2#the hessian_vector_product plausibly works!

function g(x)
    k = x[1:n + 1]
    B = SparseArrays.spdiagm(0=>k[1:end - 1] + k[2:end], -1=>-k[2:end - 1], 1=>-k[2:end - 1])
    ml = AlgebraicMultigrid.ruge_stuben(B)
    return sum(AlgebraicMultigrid.solve(ml, x[n + 2:end]))
end
x = randn(2 * n + 1)
v = randn(2 * n + 1)
hessian_vector_product(g, x, v)#seems to fail during the coarse solve in AlgebraicMultigrid
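In g, the matrix B built from k = x[1:n + 1] is tridiagonal. Its diagonal is k[i] + k[i+1] and its off-diagonals are -k[i+1]. This is the stiffness-matrix pattern of a chain of n + 1 springs with both ends fixed, i.e. B = D' * Diagonal(k) * D for an (n+1)×n difference matrix D.

The sketch below checks that factorization using only the standard library. The helper names tridiag_from_k and difference_matrix are hypothetical. It assumes strictly positive k, in which case B is symmetric positive definite, the setting classical Ruge-Stuben AMG is usually applied to. The example above draws k from randn, so some entries can be negative and positive definiteness is then not guaranteed.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: build B exactly as g does from a vector k of length n + 1.
tridiag_from_k(k) = spdiagm(0 => k[1:end-1] + k[2:end],
                            -1 => -k[2:end-1],
                            1 => -k[2:end-1])

# Hypothetical helper: (n+1)×n difference matrix. Row j is spring j, joining
# node j-1 and node j; nodes 0 and n+1 are fixed, so they have no column.
function difference_matrix(n)
    D = spzeros(n + 1, n)
    for i in 1:n
        D[i, i] = 1.0
        D[i + 1, i] = -1.0
    end
    return D
end

n = 4
k = rand(n + 1) .+ 0.1             # assumption: strictly positive stiffnesses
B = tridiag_from_k(k)
D = difference_matrix(n)
B_factored = D' * Diagonal(k) * D

println(B ≈ B_factored)            # same matrix, assembled two ways
println(isposdef(Matrix(B)))       # SPD because every k[i] > 0
```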
ranjanan commented 4 years ago

Pasting the stack trace here for reference:

julia> hessian_vector_product(g, x, v)#seems to fail during the coarse solve in AlgebraicMultigrid
ERROR: MethodError: no method matching svd!(::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},2}; full=false, alg=LinearAlgebra.DivideAndConquer())
Closest candidates are:
  svd!(::LinearAlgebra.AbstractTriangular; kwargs...) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\triangular.jl:2672
  svd!(::StridedArray{T, 2}; full, alg) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\svd.jl:93
  svd!(::StridedArray{T, 2}, ::StridedArray{T, 2}) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\svd.jl:363 got unsupported keyword arguments "full", "alg"
  ...
Stacktrace:
 [1] svd(::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},2}; full::Bool, alg::LinearAlgebra.DivideAndConquer) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\svd.jl:158
 [2] pinv(::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},2}; atol::Float64, rtol::Float64) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\dense.jl:1356
 [3] pinv at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LinearAlgebra\src\dense.jl:1335 [inlined]
 [4] #adjoint#762 at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\lib\array.jl:400 [inlined]
 [5] adjoint at .\none:0 [inlined]
 [6] _pullback at C:\Users\Ranjan Anantharaman\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [7] Pinv at C:\Users\Ranjan Anantharaman\.julia\packages\AlgebraicMultigrid\RU7pA\src\multilevel.jl:57 [inlined]
 [8] _pullback(::Zygote.Context, ::Type{AlgebraicMultigrid.Pinv{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1}}}, ::SparseArrays.SparseMatrixCSC{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},Int64}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [9] Pinv at C:\Users\Ranjan Anantharaman\.julia\packages\AlgebraicMultigrid\RU7pA\src\multilevel.jl:59 [inlined]
 [10] _pullback(::Zygote.Context, ::Type{AlgebraicMultigrid.Pinv}, ::SparseArrays.SparseMatrixCSC{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},Int64}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [11] #ruge_stuben#13 at C:\Users\Ranjan Anantharaman\.julia\packages\AlgebraicMultigrid\RU7pA\src\classical.jl:44 [inlined]
 [12] _pullback(::Zygote.Context, ::AlgebraicMultigrid.var"##ruge_stuben#13", ::AlgebraicMultigrid.Classical{Float64}, ::AlgebraicMultigrid.RS, ::AlgebraicMultigrid.GaussSeidel{AlgebraicMultigrid.SymmetricSweep}, ::AlgebraicMultigrid.GaussSeidel{AlgebraicMultigrid.SymmetricSweep}, ::Int64, ::Int64, ::Type{AlgebraicMultigrid.Pinv}, ::typeof(AlgebraicMultigrid.ruge_stuben), ::SparseArrays.SparseMatrixCSC{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},Int64}, ::Type{Val{1}}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [13] ruge_stuben at C:\Users\Ranjan Anantharaman\.julia\packages\AlgebraicMultigrid\RU7pA\src\classical.jl:20 [inlined] (repeats 2 times)
 [14] _pullback(::Zygote.Context, ::typeof(AlgebraicMultigrid.ruge_stuben), ::SparseArrays.SparseMatrixCSC{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},Int64}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [15] g at .\REPL[20]:4 [inlined]
 [16] _pullback(::Zygote.Context, ::typeof(g), ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface2.jl:0
 [17] _pullback(::Function, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:38
 [18] pullback(::Function, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:44
 [19] gradient(::Function, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}) at C:\Users\Ranjan Anantharaman\.julia\packages\Zygote\rqvFi\src\compiler\interface.jl:53
 [20] (::var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}})(::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}) at .\REPL[10]:1
 [21] vector_mode_dual_eval at C:\Users\Ranjan Anantharaman\.julia\packages\ForwardDiff\sdToQ\src\apiutils.jl:37 [inlined]
 [22] vector_mode_jacobian(::var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}}) at C:\Users\Ranjan Anantharaman\.julia\packages\ForwardDiff\sdToQ\src\jacobian.jl:140
 [23] jacobian(::Function, ::Array{Float64,1}, ::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}}, ::Val{true}) at C:\Users\Ranjan Anantharaman\.julia\packages\ForwardDiff\sdToQ\src\jacobian.jl:17
 [24] jacobian(::Function, ::Array{Float64,1}, ::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1,Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#1#2"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1},1}}) at C:\Users\Ranjan Anantharaman\.julia\packages\ForwardDiff\sdToQ\src\jacobian.jl:15 (repeats 2 times)
 [25] hessian_vector_product(::Function, ::Array{Float64,1}, ::Array{Float64,1}) at .\REPL[10]:1
 [26] top-level scope at REPL[23]:1
ChrisRackauckas commented 3 years ago

Try loading https://github.com/JuliaLinearAlgebra/GenericSVD.jl and see whether that brings in the right dispatch. You may need to remove the kwargs from the svd! call so that an available method matches.
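The trace above shows where the MethodError comes from. The Pinv coarse solver calls pinv, which goes through a dense SVD, and the stdlib svd! methods listed under "Closest candidates" only cover Float32, Float64, and their complex versions. GenericSVD is one way to supply a generic method.

A different route, swapped in here only to illustrate the dispatch issue, avoids the SVD altogether by solving the coarsest system with LinearAlgebra's generic LU factorization. This is not AlgebraicMultigrid's API. It is valid only when the coarsest matrix is nonsingular, whereas pinv also tolerates singular ones. Judging from the trace, ruge_stuben receives the coarse solver type (Pinv) as an argument, but the exact hook for plugging in a different solver is not shown here.

In this sketch, BigFloat stands in for ForwardDiff.Dual so that only the standard library is needed, and dense_coarse_solve is a hypothetical helper. Whether this also differentiates cleanly under Zygote with Dual elements is not checked.

```julia
using LinearAlgebra

# Hypothetical helper: solve the coarsest system A * x = b with LU instead of pinv.
# For non-BLAS element types such as BigFloat, lu falls back to a generic
# (pure Julia) LU with partial pivoting, so no SVD method is required.
function dense_coarse_solve(A::AbstractMatrix, b::AbstractVector)
    F = lu(A)     # throws if A is exactly singular (check = true by default)
    return F \ b
end

A = BigFloat[4 -1 0; -1 4 -1; 0 -1 4]   # small SPD stand-in for a coarse matrix
b = BigFloat[1, 2, 3]

x_lu = dense_coarse_solve(A, b)
println(eltype(x_lu))                   # element type is preserved
println(norm(A * x_lu - b))             # residual at BigFloat precision
```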

omalled commented 3 years ago

Here's an updated version with using GenericSVD added. It produces a StackOverflowError instead of the MethodError from the earlier stack trace:

using Test
import AlgebraicMultigrid
import ForwardDiff
using GenericSVD
import LinearAlgebra
import SparseArrays
import Zygote

hessian_vector_product(f, x, v) = ForwardDiff.jacobian(s->Zygote.gradient(f, x + s[1] * v)[1], [0.0])[:]

n = 4
A = randn(n, n)
hessian = A + A'
f(x) = LinearAlgebra.dot(x, A * x) 
x = randn(n)
v = randn(n)
hvp2 = hessian * v
hvp1 = hessian_vector_product(f, x, v)
@test hvp1 ≈ hvp2#the hessian_vector_product plausibly works!

function g(x)
    k = x[1:n + 1]
    B = SparseArrays.spdiagm(0=>k[1:end - 1] + k[2:end], -1=>-k[2:end - 1], 1=>-k[2:end - 1])
    ml = AlgebraicMultigrid.ruge_stuben(B)
    return sum(AlgebraicMultigrid.solve(ml, x[n + 2:end]))
end
x = randn(2 * n + 1)
v = randn(2 * n + 1)
hessian_vector_product(g, x, v)#stack overflow

Here's the stack trace:

┌ Warning: keyword `alg` ignored in generic svd!
└ @ GenericSVD ~/.julia/packages/GenericSVD/cT5Cu/src/GenericSVD.jl:12
ERROR: StackOverflowError:
Stacktrace:
 [1] givensAlgorithm(::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{typeof(g),Array{Float64,1},Array{Float64,1}},Float64},Float64,1}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/givens.jl:251 (repeats 79984 times)

Should I open an issue over at GenericSVD?

ChrisRackauckas commented 3 years ago

Yeah, that looks like some type assumption in GenericSVD was violated.
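The StackOverflowError trace shows LinearAlgebra.givensAlgorithm calling itself at the same line (givens.jl:251) tens of thousands of times for ForwardDiff.Dual arguments, presumably reached via GenericSVD. That pattern is consistent with a fallback method redispatching to itself because no more specific method matched that number type.

For reference, a Givens rotation for a pair (f, g) is a triple (c, s, r) with c*f + s*g = r and -s*f + c*g = 0. The sketch below computes it in the simplest generic way, with hypot, for any Real type that supports hypot and division. generic_givens is a hypothetical helper, not LinearAlgebra's implementation, which is more careful about overflow, underflow, and sign conventions.

```julia
# Hypothetical generic Givens rotation: returns (c, s, r) such that
# [c s; -s c] * [f, g] ≈ [r, 0].
function generic_givens(f::T, g::T) where {T<:Real}
    if iszero(g)
        # No rotation needed; this branch also covers f == g == 0.
        return one(T), zero(T), f
    end
    r = hypot(f, g)
    return f / r, g / r, r
end

f, g = big"3.0", big"4.0"          # BigFloat, a non-BLAS number type
c, s, r = generic_givens(f, g)
G = [c s; -s c]
println(r)                          # hypot(3, 4)
println(G * [f, g])                 # second entry is zero up to rounding
```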

ViralBShah commented 3 years ago

@ranjanan Whenever you have a moment, let's get this one done.

ViralBShah commented 3 years ago

Also pinging @DhairyaLGandhi. This is a bit urgent to resolve.

ranjanan commented 3 years ago

https://github.com/JuliaLinearAlgebra/GenericSVD.jl/issues/25#issuecomment-778839017