CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License
210 stars 47 forks source link

Training edge weights as well as neural network #443

Open jarroyoe opened 1 week ago

jarroyoe commented 1 week ago

I’m trying to model Graph NODEs integrating GraphNeuralNetworks.jl and OrdinaryDiffEq.jl. I am trying to learn both the neural network parameters as well as the weights of the edges, so I have to manually modify the Flux parameters during prediction. When I run the following MWE:

using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity

time = 1:10
x0 = rand(9)
obs = rand(9,10)

fullGraph = GNNGraph(complete_digraph(3))

layer1 = GCNConv(3 => 10,tanh,use_edge_weight=true)
layer2 = GCNConv(10 => 3,use_edge_weight=true)

chain = GNNChain(layer1,layer2)
pinit = ComponentArray{Float32}(weights = rand(ne(fullGraph)),
        layer1 = f64(layer1.weight),layer2 = f64(layer2.weight))

function predict(p)
    fullGraph = GNNGraph(complete_digraph(3))
    fullGraph = set_edge_weight(fullGraph,p.weights)
    chain.layers[1].weight .= p.layer1
    chain.layers[2].weight .= p.layer2

    function nn!(du,u,p,t)
        uGraph = reshape(u,(3,3))
        dGraph = reshape(chain(fullGraph,uGraph),(3*3))
        du .= dGraph
    end
    prob = ODEProblem(nn!,x0,(time[1],time[end]),saveat=time)
    sol = solve(prob)
    return Array(sol)
end

function loss_function(p)
    pred = predict(p)

    sum(abs2,pred .- obs)
end

Zygote.gradient(loss_function,pinit)

I get the following error:

ERROR: BoundsError: attempt to access 10-element UnitRange{Int64} at index [0]
Stacktrace:
  [1] throw_boundserror(A::UnitRange{Int64}, I::Int64)
    @ Base .\abstractarray.jl:737
  [2] getindex
    @ .\range.jl:930 [inlined]
  [3] (::SciMLSensitivity.ReverseLossCallback{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\adjoint_common.jl:530
  [4] #111
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqCallbacks\9fKPq\src\preset_time.jl:58 [inlined]
  [5] apply_discrete_callback!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:613 [inlined]
  [6] apply_discrete_callback!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:628 [inlined]
  [7] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:349
  [8] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:254
  [9] loopfooter!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:207 [inlined]
 [10] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:558
 [11] #__solve#670
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:7 [inlined]
 [12] __solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:1 [inlined]
 [13] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
 [14] solve_call
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:569 [inlined]
 [15] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1080
 [16] solve_up
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
 [17] #solve#51
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
 [18] solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
 [19] #__solve#675
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:547 [inlined]
 [20] __solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:546 [inlined]
 [21] solve_call(_prob::ODEProblem{…}, args::Nothing; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
 [22] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::Nothing; kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1072
 [23] solve_up
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
 [24] solve(prob::ODEProblem{…}, args::Nothing; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003
 [25] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::QuadratureAdjoint{…}, alg::Nothing; t::UnitRange{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:340
 [26] _adjoint_sensitivities
    @ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:328 [inlined]
 [27] #adjoint_sensitivities#63
    @ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\sensitivity_interface.jl:386 [inlined]
 [28] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\concrete_solve.jl:582
 [29] ZBack
    @ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211 [inlined]
 [30] (::Zygote.var"#291#292"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206
 [31] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [32] #solve#51
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
 [33] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [34] #291
    @ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [35] #2169#back
    @ C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [36] solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
 [37] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [38] predict
    @ .\REPL[11]:13 [inlined]
 [39] (::Zygote.Pullback{Tuple{typeof(predict), ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}}, Any})(Δ::Matrix{Float64})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [40] loss_function
    @ .\REPL[12]:2 [inlined]
 [41] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [42] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [43] top-level scope
    @ REPL[17]:1

I'm crossposting this from the discourse as I don't know if this is necessarily a bug with GraphNeuralNetworks.jl or if the devs know a better alternative to do these kinds of processes.

Thanks!