CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Failure to combine `SparseDiffTools.autoback_hesvec` and `GCNConv` #125

Closed (newalexander closed this issue 2 years ago)

newalexander commented 2 years ago

Hello! Nice work on the library; it is very usable. I'm trying to calculate the Hessian-vector product of a loss function involving GNNGraph data points and a GNNChain model. I've been using the SparseDiffTools.jl function autoback_hesvec for this, which implements ForwardDiff.jl over Zygote.jl for the Hessian-vector calculation. However, this function fails in the GraphNeuralNetworks.jl setting. The other Hessian-vector functions in SparseDiffTools.jl do work, and an analogously constructed calculation using only Flux works.

using GraphNeuralNetworks, Flux, Graphs, ForwardDiff, Random, SparseDiffTools

function gnn_test()
    Random.seed!(1234)

    g = GNNGraph(erdos_renyi(10,  30), ndata=rand(Float32, 3, 10), gdata=rand(Float32, 2))

    m = GNNChain(GCNConv(3 => 2, tanh), GlobalPool(+))
    ps, re = Flux.destructure(m)  # primal vector and restructure function
    ts = rand(Float32, size(ps))  # tangent vector

    loss(_ps) = Flux.Losses.mse(re(_ps)(g, g.ndata.x), g.gdata.u)

    numback_hesvec(loss, ps, ts) |> println  # works
    numback_hesvec(loss, ps, ts)  |> println  # works
    numauto_hesvec(loss, ps, ts)  |> println  # works
    autoback_hesvec(loss, ps, ts) |> println  # fails
end

function flux_test()
    Random.seed!(1234)

    x = rand(Float32, 10, 3)
    y = rand(Float32, 2, 3)

    m = Chain(Dense(10, 4, tanh), Dense(4, 2))
    ps, re = Flux.destructure(m)  # primal vector and restructure function
    ts = rand(Float32, size(ps))  # tangent vector

    loss(_ps) = Flux.Losses.mse(re(_ps)(x), y)

    numback_hesvec(loss, ps, ts) |> println  # works
    numback_hesvec(loss, ps, ts)  |> println  # works
    numauto_hesvec(loss, ps, ts)  |> println  # works
    autoback_hesvec(loss, ps, ts) |> println  # works
end
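
For context, autoback_hesvec computes a forward-over-reverse Hessian-vector product: the parameters are seeded with ForwardDiff dual numbers in the direction of the tangent vector, and a Zygote gradient is taken at that dual-valued point. The sketch below is only illustrative (hypothetical helper name, not the actual SparseDiffTools implementation), but it shows why ForwardDiff.Dual values end up flowing through Zygote's pullbacks:

using ForwardDiff, Zygote

# Illustrative forward-over-reverse sketch: Hv = d/dε ∇f(x + ε v) at ε = 0.
function hesvec_sketch(f, x, v)
    TagT = typeof(ForwardDiff.Tag(f, eltype(x)))  # tag type for the dual numbers
    xdual = ForwardDiff.Dual{TagT}.(x, v)         # seed x with one partial in direction v
    gdual = Zygote.gradient(f, xdual)[1]          # reverse-mode gradient at the dual-valued point
    return ForwardDiff.partials.(gdual, 1)        # directional derivative of the gradient
end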

The full error message:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseArrays.SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1335
    @ ~/.julia/packages/ChainRules/3HAQW/src/rulesets/Base/arraymath.jl:37 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:189 [inlined]
 [11] (::typeof(∂(propagate)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:68 [inlined]
 [13] (::typeof(∂(#propagate#84)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/msgpass.jl:68 [inlined]
 [15] (::typeof(∂(propagate##kw)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/conv.jl:103 [inlined]
 [17] (::typeof(∂(λ)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/conv.jl:80 [inlined]
 [19] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:125 [inlined]
 [20] (::typeof(∂(applylayer)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:137 [inlined]
 [22] (::typeof(∂(applychain)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/.julia/packages/GraphNeuralNetworks/HAl1C/src/layers/basic.jl:139 [inlined]
 [24] (::typeof(∂(λ)))(Δ::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/JuliaProjects/GraphNetworkLayers/test/fwd.jl:15 [inlined]
 [26] (::typeof(∂(λ)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#57#58"{typeof(∂(λ))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [28] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [29] (::SparseDiffTools.var"#78#79"{var"#loss#5"{Flux.var"#66#68"{GNNChain{Tuple{GCNConv{Matrix{Float32}, Vector{Float32}, typeof(tanh)}, GlobalPool{typeof(+)}}}}, GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:39
 [30] autoback_hesvec(f::Function, x::Vector{Float32}, v::Vector{Float32})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:41
 [31] gnn_test()
    @ Main [script location]
 [32] top-level scope
    @ REPL[8]:1
 [33] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52
CarloLucibello commented 2 years ago

Thanks for this very clear report! The issue could be in ChainRulesCore.ProjectTo{SparseMatrixCSC} not being ForwardDiff.Dual-friendly.

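If that's the case, a projector built from a sparse matrix should already choke on a dual-valued cotangent in isolation. A minimal sketch of that suspicion (hypothetical repro on the package versions from the stack trace above):

using ChainRulesCore, ForwardDiff, SparseArrays

A = sprand(3, 3, 0.5)
proj = ChainRulesCore.ProjectTo(A)             # ProjectTo{SparseMatrixCSC} with a Float64 element projector
dx = fill(ForwardDiff.Dual(1.0, 1.0), 3, 3)    # dense Dual-valued cotangent, as produced in the pullback above
proj(dx)                                       # expected to throw the same MethodError: Float64(::Dual)
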
CarloLucibello commented 2 years ago

It works fine for the generic message-passing framework (e.g. for the GATConv):

julia> using GraphNeuralNetworks, Flux, Graphs, ForwardDiff, Random, SparseDiffTools

julia> function gnn_test()
           Random.seed!(1234)

           g = GNNGraph(erdos_renyi(10,  30), ndata=rand(Float32, 3, 10))
           m = GATConv(3 => 2, tanh)
           ps, re = Flux.destructure(m)  # primal vector and restructure function
           ts = rand(Float32, size(ps))  # tangent vector

           loss(_ps) = sum(re(_ps)(g, g.ndata.x))

           numback_hesvec(loss, ps, ts) |> println  # works
           numback_hesvec(loss, ps, ts)  |> println  # works
           numauto_hesvec(loss, ps, ts)  |> println  # works
           autoback_hesvec(loss, ps, ts) |> println  # fails for GCNConv, works for GATConv
       end
gnn_test (generic function with 1 method)

julia> gnn_test()
Float32[4.4092455, -0.75034475, 2.9557762, -0.39457783, 2.5744586, 0.34606418, 8.512531, 0.21992864, -0.2590933, -0.5618502, 0.4025826, -1.1103446]
Float32[4.4092455, -0.75034475, 2.9557762, -0.39457783, 2.5744586, 0.34606418, 8.512531, 0.21992864, -0.2590933, -0.5618502, 0.4025826, -1.1103446]
Float32[4.4092455, -0.751315, 2.9551294, -0.39425442, 2.5746205, 0.34541732, 8.512531, 0.21992864, -0.25904277, -0.56184894, 0.40250173, -1.1103705]
Float32[4.409217, -0.7511167, 2.9556952, -0.3946823, 2.5747547, 0.34566167, 8.511975, 0.21927829, -0.25907597, -0.5618338, 0.40243906, -1.1103226]

but it doesn't like operations involving sparse matrices (i.e. what happens inside GCNConv).

CarloLucibello commented 2 years ago

The problem seems to be a very generic one, not strictly related to GNN.jl. Here is an MWE:

julia> using SparseArrays, SparseDiffTools

julia> x, t = rand(5), rand(5);

julia> A = sprand(5, 5, 0.5);

julia> loss(x) = sum(tanh.(A * x));

julia> numback_hesvec(loss, x, t) # works
5-element Vector{Float64}:
 -0.349703846209146
 -1.210662833747414
 -1.4030571895355597
 -0.47786341057923254
 -0.9171474544184983

julia> autoback_hesvec(loss, x, t)
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at ~/julia/julia-1.7.1/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at ~/julia/julia-1.7.1/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at ~/julia/julia-1.7.1/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1334
    @ ~/.julia/packages/ChainRules/GRzER/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ./REPL[40]:1 [inlined]
 [11] (::typeof(∂(loss)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#57#58"{typeof(∂(loss))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [14] (::SparseDiffTools.var"#78#79"{typeof(loss)})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/b2cgD/src/differentiation/jaches_products_zygote.jl:39
 [15] autoback_hesvec(f::Function, x::Vector{Float64}, v::Vector{Float64})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/b2cgD/src/differentiation/jaches_products_zygote.jl:41
 [16] top-level scope
    @ REPL[47]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52

I think this issue should be reported to SparseDiffTools.jl and probably fixed in ChainRules.jl.