CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

outputsize for GNNChain #96

Closed: tclements closed this issue 2 years ago

tclements commented 2 years ago

This is a feature request: it would be nice to extend Flux.outputsize to GNNChains. I imagine it could accept either a WithGraph, or a GNNGraph together with a tuple giving the input size. Here's a sketch of an MWE adapted from the docs:

using Flux, Graphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2 
g = rand_graph(10, 30)
X = randn(Float32, din, 10)
inputsize = size(X) 

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GCNConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))
wg = WithGraph(model, g)

# Proposed (hypothetical) API: outputsize taking the model, the graph and the input size
@assert GraphNeuralNetworks.outputsize(model, g, inputsize) == size(model(g, X))
@assert GraphNeuralNetworks.outputsize(wg, inputsize) == size(wg(X))
CarloLucibello commented 2 years ago

Fortunately, the Flux.outputsize implementation is generic enough to work with almost any callable. A WithGraph can be passed directly, and a model that takes the graph as an extra argument just needs to be wrapped in a closure:

julia> Flux.outputsize(wg, (3, 10))
(2, 10)

julia> Flux.outputsize(x -> model(g, x), (3, 10))
(2, 10)
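For reference, the two assertions from the original request can be expressed with the existing Flux.outputsize instead of the hypothetical GraphNeuralNetworks.outputsize. This is a minimal sketch, assuming Flux, Graphs and GraphNeuralNetworks versions contemporary with this thread (where the (2, 10) outputs above were reported); it reuses the model from the request and checks the predicted shapes against an actual forward pass.

```julia
using Flux, Graphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2
g = rand_graph(10, 30)            # 10 nodes, 30 edges
X = randn(Float32, din, 10)       # one feature column per node

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GCNConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))
wg = WithGraph(model, g)          # graph is stored inside the wrapper

# WithGraph already has the single-argument form outputsize expects.
@assert Flux.outputsize(wg, size(X)) == size(wg(X))

# A GNNChain called as model(g, x) needs the graph captured in a closure.
@assert Flux.outputsize(x -> model(g, x), size(X)) == size(model(g, X))
```

Capturing the graph in a closure (or in WithGraph) keeps outputsize unaware of graphs entirely, which is why no GNN-specific method was needed.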
tclements commented 2 years ago

Great, thank you!