CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License
Add `Flux.Recur` and `TGCN` #319

Closed: aurorarossi closed this 11 months ago

CarloLucibello commented 11 months ago

We should make `TGCN` compatible with `GNNChain`. This requires overloading `_applylayer` so that it does for `TGCN` what it already does for `GNNLayer`:

https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/1a4c62ba8351cf0744bdaac3ab7821ace25fa8de/src/layers/basic.jl#L158
https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/1a4c62ba8351cf0744bdaac3ab7821ace25fa8de/src/layers/basic.jl#L162

This means defining:

_applylayer(l::Recur{TGCNCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Recur{TGCNCell}, g::GNNGraph) = l(g)
(l::Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
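
To illustrate why these overloads are needed, here is a minimal, self-contained sketch of the dispatch pattern. It does not use Flux or GraphNeuralNetworks.jl: `MockGraph`, `MockGNNLayer`, `MockCell`, `MockRecur`, and `applychain` are hypothetical stand-ins invented for this example, and the local `_applylayer` only imitates the role of the real one as I understand it from the linked lines. The point is that the recurrent wrapper is not itself a GNN layer, so without dedicated methods the chain would fall back to calling it as `l(x)` and drop the graph.

```julia
# Hypothetical stand-ins, NOT the real Flux / GraphNeuralNetworks.jl types.
struct MockGraph
    ndata::Vector{Float64}          # node features stored on the graph
end
node_features(g::MockGraph) = g.ndata

abstract type MockGNNLayer end

struct MockCell <: MockGNNLayer
    scale::Float64
end

# A cell maps (state, graph, input) to (new_state, output).
function (c::MockCell)(h, g::MockGraph, x)
    h_new = h .+ c.scale .* x
    return h_new, h_new
end

# Minimal analogue of a recurrent wrapper: holds a cell and its mutable state.
mutable struct MockRecur{T,S}
    cell::T
    state::S
end

function (r::MockRecur)(g::MockGraph, x)
    r.state, y = r.cell(r.state, g, x)
    return y
end

# Graph-in, graph-out form, mirroring the third definition above.
(r::MockRecur{MockCell})(g::MockGraph) = MockGraph(r(g, node_features(g)))

# Generic fallbacks: ordinary layers ignore the graph.
_applylayer(l, g::MockGraph, x) = l(x)
_applylayer(l, g::MockGraph) = g
# GNN layers receive the graph.
_applylayer(l::MockGNNLayer, g::MockGraph, x) = l(g, x)
_applylayer(l::MockGNNLayer, g::MockGraph) = l(g)
# The wrapper is not a MockGNNLayer, so it needs its own methods.
_applylayer(l::MockRecur{MockCell}, g::MockGraph, x) = l(g, x)
_applylayer(l::MockRecur{MockCell}, g::MockGraph) = l(g)

# Toy chain: threads x through the layers, passing g to each one.
function applychain(layers, g::MockGraph, x)
    for l in layers
        x = _applylayer(l, g, x)
    end
    return x
end

g = MockGraph([1.0, 2.0])
layer = MockRecur(MockCell(0.5), zeros(2))
y1 = applychain((layer,), g, node_features(g))  # first step updates the state
y2 = applychain((layer,), g, node_features(g))  # second step builds on it
```

Removing the two `MockRecur{MockCell}` methods of `_applylayer` makes the chain hit the generic fallback `l(x)`, which has no matching method on the wrapper. That is the incompatibility the proposed definitions would address, under the assumptions of this sketch.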