CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Enzyme.jl compatibility #389

Open CarloLucibello opened 4 months ago

CarloLucibello commented 4 months ago

Here we will keep track of compatibility with Enzyme for taking gradients. The first thing is to collect a few examples to run.

askorupka commented 3 months ago

Hi @CarloLucibello, as per our discussion I've set up a working example for Flux and then extended it to GraphNeuralNetworks. Here is my code:

using Flux, Random, Enzyme, GraphNeuralNetworks
rng = Random.default_rng()

loss(model, x) = sum(model(g, g.x))

model = GNNChain(GCNConv(2=>5), 
                    BatchNorm(5), 
                    x -> relu.(x), 
                    Dense(5, 4))
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

g.ndata.x = x

grads_zygote = Flux.gradient(model->loss(model, x), model)[1]

dx = grads_enzyme = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Duplicated(x, dx))

The last line fails with an error message suggesting that Enzyme's Duplicated is not compatible with GNNChain (unlike Flux's Chain):

ERROR: MethodError: no method matching Duplicated(::Int64, ::GNNChain{Tuple{GCNConv{…}, BatchNorm{…}, var"#47#48", Dense{…}}})

Closest candidates are:
  Duplicated(::T1, ::T1) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64
  Duplicated(::T1, ::T1, ::Bool) where T1
   @ EnzymeCore ~/.julia/packages/EnzymeCore/XBDTI/src/EnzymeCore.jl:64

Stacktrace:
 [1] top-level scope
   @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:56
Some type information was truncated. Use `show(err)` to see complete types.

What is the best way to approach the issue? I was looking into EnzymeCore/ line 65 and I think the method should be extended from T1 to GNNChain, but I can't find a definition of T1 anywhere. Any ideas?
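
On the T1 question: in the candidate signature `Duplicated(::T1, ::T1) where T1`, `T1` is a type variable introduced by the `where` clause rather than a type defined somewhere in the package, so both arguments have to share one type. The sketch below illustrates this with a hypothetical `Shadowed` struct (a stand-in for illustration, not EnzymeCore's actual definition); it reproduces the same kind of MethodError when the two argument types differ.

```julia
# Hypothetical stand-in for an annotation type that pairs a value with its shadow.
# This is NOT EnzymeCore's definition; it only mimics the `(::T, ::T) where T` shape.
struct Shadowed{T}
    val::T
    dval::T
end

# Both arguments are Vector{Float64}, so T is inferred as Vector{Float64}.
a = Shadowed([1.0, 2.0], [0.0, 0.0])

# An Int and a Vector{Float64}: no single T matches both, so dispatch fails.
err = try
    Shadowed(1, [0.0, 0.0])
catch e
    e
end

println(typeof(a))
println(err isa MethodError)
```

Under this reading, the error reported above points to a primal/shadow pair whose types do not match, which would be addressed by passing two values of the same type rather than by adding a new method for GNNChain.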

CarloLucibello commented 3 months ago

The example script has some bugs; for instance, the shape of dx was not correct. Here is the corrected script, with a simplified model as well:

using Flux, Random, Enzyme, GraphNeuralNetworks

loss(model, x) = sum(model(g, x))

model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

grads_zygote = Flux.gradient(loss, model, x)

dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

dx = zero(x)

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))

Enzyme throws an error here as well. The fundamental building blocks of message passing should be tested first, i.e. the operations defined or used in https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/src/msgpass.jl. Most probably we will need a rule for gather and scatter, so I would test those operations first.

CarloLucibello commented 3 months ago

Actually, the Enzyme rules for gather and scatter are already in NNlib: https://github.com/FluxML/NNlib.jl/blob/master/ext/NNlibEnzymeCoreExt/NNlibEnzymeCoreExt.jl

askorupka commented 3 months ago

Thank you for the code, @CarloLucibello. I've tried it after loading the NNlibEnzymeCore extension, but I'm not sure the Enzyme.autodiff call in the snippet above is correct, as it yields a StackOverflowError. This usually happens "when the call stack exceeds its maximum size, typically due to infinite recursion" (source). Does that mean the Enzyme.autodiff call is recursive in this example?

On the positive side, I no longer get the MethodError.

using Flux, Random, Enzyme, GraphNeuralNetworks
using NNlib, EnzymeCore
rng = Random.default_rng()

loss(model, x) = sum(model(g, x))

model = GraphConv(2 => 5)
x = randn(Float32, 2, 3);
g = rand_graph(3, 6)

dmodel = Flux.fmap(model) do x
    x isa Array ? zero(x) : x
end

dx = zero(x)

Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Duplicated(x, dx))

ERROR: StackOverflowError:
Stacktrace:
     [1] getproperty
       @ ./Base.jl:32 [inlined]
     [2] unwrap_unionall
       @ ./essentials.jl:379 [inlined]
     [3] fieldnames
       @ ./reflection.jl:169 [inlined]
     [4] augmented_julia_fieldnames_8170wrap
       @ ./reflection.jl:0
     [5] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
     [6] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
     [7] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
     [8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
 [27675] check_num_nodes
       @ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
 [27676] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]

askorupka commented 3 months ago

Let me paste the whole stacktrace here.

Stacktrace:
     [1] getproperty
       @ ./Base.jl:32 [inlined]
     [2] unwrap_unionall
       @ ./essentials.jl:379 [inlined]
     [3] fieldnames
       @ ./reflection.jl:169 [inlined]
     [4] augmented_julia_fieldnames_8170wrap
       @ ./reflection.jl:0
     [5] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
     [6] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
     [7] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
     [8] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(fieldnames), df::Nothing, primal_1::Type{…}, shadow_1_1::Nothing)
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
--- the last 6 lines are repeated 4611 more times ---
 [27675] check_num_nodes
       @ ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/utils.jl:2
 [27676] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:306 [inlined]
 [27677] GraphConv
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0 [inlined]
 [27678] augmented_julia_GraphConv_6402_inner_1wrap
       @ ~/.julia/dev/GraphNeuralNetworks/src/layers/conv.jl:0
 [27679] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [27680] enzyme_call
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
 [27681] AugmentedForwardThunk
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009 [inlined]
 [27682] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::GraphConv{…}, df::GraphConv{…}, primal_1::GNNGraph{…}, shadow_1_1::Nothing, primal_2::Matrix{…}, shadow_2_1::Matrix{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/rules/jitrules.jl:179
 [27683] loss
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:40 [inlined]
 [27684] loss
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0 [inlined]
 [27685] augmented_julia_loss_8465_inner_1wrap
       @ ~/.julia/dev/GraphNeuralNetworks/enzyme_tests.jl:0
 [27686] macro expansion
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [27687] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056
 [27688] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
       @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5009
 [27689] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
       @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:198
 [27690] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{GraphConv{…}}, ::Vararg{Any})
       @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
Some type information was truncated. Use `show(err)` to see complete types.

askorupka commented 3 months ago

Interestingly, this is the first thing in the stacktrace that causes the issue:

function (l::GraphConv)(g::AbstractGNNGraph, x)
    check_num_nodes(g, x)
    xj, xi = expand_srcdst(g, x)
    m = propagate(copy_xj, g, l.aggr, xj = xj)
    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
    return x
end

function check_num_nodes(g::GNNGraph, x::AbstractArray)
    @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)"
    return true
end

askorupka commented 3 months ago

Also linking some related issues/PRs for future testing purposes: Flux.jl PR #2392 and EnzymeAD issue #805.

CarloLucibello commented 3 months ago

Let's focus on propagate, e.g.

# Assumes g and x are defined as in the script above.
f(x) = sum(propagate(copy_xj, g, +, xj = x))
dx = zero(x)
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))