CarloLucibello opened this issue 4 months ago
Hi @CarloLucibello, as per our discussion I've set up a working example for Flux and then used this example to extend it to GraphNeuralNetworks.
Here is my code:
using Flux, Random, Enzyme, GraphNeuralNetworks
rng = Random.default_rng()
loss(model, x) = sum(model(g, g.x))
model = GNNChain(GCNConv(2 => 5),
                 BatchNorm(5),
                 x -> relu.(x),
                 Dense(5, 4))
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
g.ndata.x = x
grads_zygote = Flux.gradient(model -> loss(model, x), model)[1]
dx = grads_enzyme = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Duplicated(x, dx))
The last line results in an error message indicating that Enzyme's `Duplicated` is not compatible with `GNNChain` (contrary to Flux's `Chain`).
ERROR: MethodError: no method matching Duplicated(::Int64, ::GNNChain{Tuple{GCNConv{…}, BatchNorm{…}, var"#47#48", Dense{…}}})
Closest candidates are:
Duplicated(::T1, ::T1) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
Duplicated(::T1, ::T1, ::Bool) where T1
@ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
Stacktrace:
[1] top-level scope
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:56
Some type information was truncated. Use `show(err)` to see complete types.
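For contrast, here is a minimal sketch of the plain-Flux setup that the same pattern is reported to work with (the "working example for Flux" mentioned above); the exact layers are illustrative assumptions:

using Flux, Enzyme

model = Chain(Dense(2 => 5), relu, Dense(5 => 4))   # plain Flux Chain, no graph involved
x = randn(Float32, 2, 3)
loss(m, x) = sum(m(x))

dmodel = Flux.fmap(m -> m isa Array ? zero(m) : m, model)   # zeroed shadow, same type as model
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))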
What is the best way to approach this issue?
I was looking into EnzymeCore line 65 and I think the method should be extended from `T1` to `GNNChain`, but I can't find the definition of `T1` anywhere. Any ideas?
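For reference, `T1` in `Duplicated(::T1, ::T1) where T1` is a method type parameter rather than a type one could extend to `GNNChain`: it is bound fresh at each call and only enforces that the primal and the shadow have the same type. A minimal sketch:

using Enzyme

x  = randn(Float32, 2, 3)
dx = zero(x)          # same type as x, so T1 binds to Matrix{Float32}
Duplicated(x, dx)     # ok
# Duplicated(x, dmodel) would throw the MethodError above:
# the primal and shadow types differ, so no T1 matches.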
The example script has some bugs; for instance, the shape of `dx` was not correct. Here is the corrected script, with a simplified model as well:
using Flux, Random, Enzyme, GraphNeuralNetworks
loss(model, x) = sum(model(g, x))
model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
grads_zygote = Flux.gradient(loss, model, x)
dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))
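If the call succeeds, agreement with the Zygote gradients could be checked along these lines; `weight1`, `weight2`, and `bias` are `GraphConv`'s parameter fields, as shown in its forward pass further below:

# grads_zygote is a tuple: a NamedTuple for the model and an array for x
grads_zygote[1].weight1 ≈ dmodel.weight1
grads_zygote[1].weight2 ≈ dmodel.weight2
grads_zygote[1].bias    ≈ dmodel.bias
grads_zygote[2]         ≈ dx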
However, Enzyme throws an error here as well. The fundamental building blocks of message passing should be tested, i.e. the operations defined or used in
https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/msgpass.jl
Most probably we will need rules for `gather` and `scatter`, so I would test those operations first.
Actually, the Enzyme rules for `gather` and `scatter` are already in NNlib: https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl
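A minimal sketch of how those rules could be exercised in isolation, before involving `propagate`; the test values are illustrative:

using NNlib, Enzyme

src = randn(Float32, 3, 5)
idx = [1, 3, 2, 5]                  # gather columns 1, 3, 2, 5
f(s) = sum(NNlib.gather(s, idx))

dsrc = zero(src)
Enzyme.autodiff(Reverse, f, Active, Duplicated(src, dsrc))
# dsrc[:, j] should equal the number of times column j appears in idx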
Thank you for the code @CarloLucibello.
I've tried it after loading the NNlibEnzymeCoreExt extension, but I'm not sure if the `Enzyme.autodiff` call in the above example snippet is correct, as it yields a StackOverflowError. This usually happens "when the call stack exceeds its maximum size, typically due to infinite recursion" (source). I'm wondering if that means the `Enzyme.autodiff` call is recursive in this example?
On the positive side, I don't get the MethodError anymore.
using Flux, Random, Enzyme, GraphNeuralNetworks
using NNlib, EnzymeCore
rng = Random.default_rng()
loss(model, x) = sum(model(g, x))
model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)
dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end
dx = zero(x)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))
ERROR: StackOverflowError:
Here is the whole stacktrace:
Stacktrace:
[1] getproperty
@ ./Base.jl:32 [inlined]
[2] unwrap_unionall
@ ./essentials.jl:379 [inlined]
[3] fieldnames
@ ./reflection.jl:169 [inlined]
[4] augmented_julia_fieldnames_8170wrap
@ ./reflection.jl:0
[5] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[6] enzyme_call
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
[7] AugmentedForwardThunk
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
[8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
[27675] check_num_nodes
@ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
[27676] GraphConv
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]
[27677] GraphConv
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0 [inlined]
[27678] augmented_julia_GraphConv_6402_inner_1wrap
@ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0
[27679] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[27680] enzyme_call
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
[27681] AugmentedForwardThunk
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
[27682] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::GraphConv{…}, df::GraphConv{…}, primal_1::GNNGraph{…}, shadow_1_1::Nothing, primal_2::Matrix{…}, shadow_2_1::Matrix{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
[27683] loss
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:40 [inlined]
[27684] loss
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0 [inlined]
[27685] augmented_julia_loss_8465_inner_1wrap
@ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0
[27686] macro expansion
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
[27687] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056
[27688] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009
[27689] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:198
[27690] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
@ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
Some type information was truncated. Use `show(err)` to see complete types.
Interestingly, this is the first thing in the stacktrace that triggers the issue:
function (l::GraphConv)(g::AbstractGNNGraph, x)
    check_num_nodes(g, x)
    xj, xi = expand_srcdst(g, x)
    m = propagate(copy_xj, g, l.aggr, xj = xj)
    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
    return x
end

function check_num_nodes(g::GNNGraph, x::AbstractArray)
    @assert g.num_nodes == size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)"
    return true
end
Also linking some related issues/PRs for future testing purposes: Flux.jl PR #2392, EnzymeAD issue #805.
Let's focus on `propagate`, e.g.

f(x) = sum(propagate(copy_xj, g, +, xj = x))
dx = zero(x)
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
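To check the result, one could compare against the Zygote gradient of the same function; a sketch:

dx_zygote = Flux.gradient(f, x)[1]
dx ≈ dx_zygote   # agreement would confirm the gather/scatter rules behave correctly here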
Here we will keep track of compatibility with Enzyme for taking gradients. The first thing is to collect a few examples to run.
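To collect such examples systematically, a small helper along these lines could be shared by the tests; `test_enzyme_gradient` is a hypothetical name, and Zygote (via `Flux.gradient`) is assumed as the reference:

using Flux, Enzyme

# Hypothetical helper: differentiate f at x with Enzyme's reverse mode and
# compare against the Zygote gradient obtained through Flux.gradient.
function test_enzyme_gradient(f, x)
    dx = zero(x)
    Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
    dx_zygote = Flux.gradient(f, x)[1]
    return dx ≈ dx_zygote
end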