I’m trying to model graph neural ODEs (Graph NODEs) by combining GraphNeuralNetworks.jl and OrdinaryDiffEq.jl. I want to learn both the neural network parameters and the edge weights, so I have to manually overwrite the Flux parameters inside the prediction function. When I run the following MWE:
```julia
using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity

tsteps = 1:10          # save times (renamed from `time` to avoid shadowing Base.time)
x0 = rand(9)           # initial state: 3 features x 3 nodes, flattened
obs = rand(9, 10)      # synthetic observations, one column per save time

fullGraph = GNNGraph(complete_digraph(3))
layer1 = GCNConv(3 => 10, tanh, use_edge_weight=true)
layer2 = GCNConv(10 => 3, use_edge_weight=true)
chain = GNNChain(layer1, layer2)

# Trainable parameters: one weight per edge plus the two layers' weight matrices
pinit = ComponentArray{Float32}(weights = rand(ne(fullGraph)),
                                layer1 = f64(layer1.weight), layer2 = f64(layer2.weight))

function predict(p)
    fullGraph = GNNGraph(complete_digraph(3))
    fullGraph = set_edge_weight(fullGraph, p.weights)
    # Copy the current parameters into the global chain, in place
    chain.layers[1].weight .= p.layer1
    chain.layers[2].weight .= p.layer2
    function nn!(du, u, p, t)
        uGraph = reshape(u, (3, 3))   # (features, nodes)
        dGraph = reshape(chain(fullGraph, uGraph), (3 * 3))
        du .= dGraph
    end
    # Floating-point time span and an explicit solver
    prob = ODEProblem(nn!, x0, (float(tsteps[1]), float(tsteps[end])), saveat=tsteps)
    sol = solve(prob, Tsit5())
    return Array(sol)
end

function loss_function(p)
    pred = predict(p)
    sum(abs2, pred .- obs)
end

Zygote.gradient(loss_function, pinit)
```
Running this MWE throws an error. I'm crossposting this from Discourse because I don't know whether this is a bug in GraphNeuralNetworks.jl or whether the devs know a better way to do this kind of thing.
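One restructuring I've been considering, as a rough sketch I haven't checked against every package version: my guess is that the in-place `.=` writes into `chain.layers[i].weight` are a problem, since Zygote generally doesn't support array mutation, and the layer weights reach the ODE through a closure instead of through `p`, so the adjoint can't see them. The sketch below instead flattens the chain with `Flux.destructure`, keeps the edge weights and the flat network parameters together in one `ComponentArray` passed as the ODE's `p`, and uses an out-of-place right-hand side. The names `g0`, `rhs` and `re`, the switch to Float32, and the choice of `InterpolatingAdjoint(autojacvec = ZygoteVJP())` are my own assumptions. I also haven't confirmed that `set_edge_weight` plus `GCNConv` with edge weights is fully differentiable under Zygote.

```julia
using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity

tsteps = 1f0:1f0:10f0          # Float32 save times
x0 = rand(Float32, 9)          # 3 features x 3 nodes, flattened
obs = rand(Float32, 9, 10)     # synthetic observations

g0 = GNNGraph(complete_digraph(3))   # fixed topology; only edge weights are learned
chain = GNNChain(GCNConv(3 => 10, tanh, use_edge_weight=true),
                 GCNConv(10 => 3, use_edge_weight=true))

# θ: all trainable chain parameters as one flat vector
# re: rebuilds a chain with the same structure from such a vector
θ, re = Flux.destructure(chain)

# Edge weights and network parameters live in a single parameter object
pinit = ComponentArray(weights = rand(Float32, ne(g0)), nn = θ)

# Out-of-place RHS: every trainable quantity enters through p, nothing is mutated
function rhs(u, p, t)
    g = set_edge_weight(g0, p.weights)
    model = re(p.nn)
    return vec(model(g, reshape(u, 3, 3)))   # (features, nodes) -> flat state
end

function predict(p)
    prob = ODEProblem(rhs, x0, (first(tsteps), last(tsteps)), p)
    sol = solve(prob, Tsit5(); saveat = tsteps,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return Array(sol)
end

loss(p) = sum(abs2, predict(p) .- obs)

grads = Zygote.gradient(loss, pinit)[1]
```

Passing `p` explicitly to `ODEProblem` is what lets SciMLSensitivity compute gradients with respect to both the edge weights and the network parameters, while `re(p.nn)` rebuilds the chain functionally on each call instead of writing into a global one.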
Thanks!