CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

Some design issues #193

Closed YichengDWu closed 2 years ago

YichengDWu commented 2 years ago
  1. Node level tasks. We can use WithGraph to embed the graph in a layer. However, if I want to save the learned parameters and reuse them on another graph (transfer learning), this becomes a problem, since the graph is already embedded in the model. How should I save the following model?
    model = Chain(Dense(3 => 4), WithGraph(GCNConv(4 => 4), g), Dense(4 => 4))
  2. Graph level tasks. This is related to the first one. We would like to use GNNChain to accept a graph as input. But this assumes any model has to start with a GNN layer. If I change the model to be
    model = Chain(Dense(nin => nin),
                  GNNChain(GraphConv(nin => nhidden, relu),
                           GraphConv(nhidden => nhidden, relu),
                           GlobalPool(mean),
                           Dense(nhidden, 1)))

    in the example, it will not work.

But yeah, if we define a type of our own, that should be able to solve these issues.

YichengDWu commented 2 years ago

Maybe GNNChain should be redesigned to be a generalization of Chain. It would sequentially check whether each layer is a GNNLayer: if yes, feed it (g, x); otherwise, feed it x. We would not need to use Chain anymore. WithGraph should be deprecated, I guess.
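
The dispatch rule described above can be sketched in plain Julia. This is only an illustration under stated assumptions: `ToyGNNLayer`, `ToyConv`, `ToyChain`, and `applylayer` are hypothetical names, not part of GraphNeuralNetworks.jl, and the graph is represented by a bare adjacency matrix.

```julia
# Hypothetical stand-ins; none of these names come from GraphNeuralNetworks.jl.
abstract type ToyGNNLayer end

# Toy graph layer: each node's output is the sum of its neighbours' features.
struct ToyConv <: ToyGNNLayer end
(l::ToyConv)(adj::AbstractMatrix, x::AbstractMatrix) = x * adj

# A chain that stores its layers in a tuple, similar in spirit to Flux.Chain.
struct ToyChain{T<:Tuple}
    layers::T
end
ToyChain(layers...) = ToyChain(layers)

# Graph layers receive (g, x); every other callable receives only x.
applylayer(l::ToyGNNLayer, g, x) = l(g, x)
applylayer(l, g, x) = l(x)

function (c::ToyChain)(g, x)
    for l in c.layers
        x = applylayer(l, g, x)
    end
    return x
end

# A path graph 1 - 2 - 3 with one feature per node.
adj = [0 1 0; 1 0 1; 0 1 0]
x = [1.0 2.0 3.0]

# The chain starts with a plain function, not a graph layer.
m = ToyChain(x -> 2 .* x, ToyConv(), x -> x .+ 1)
y = m(adj, x)
```

Because the choice between `l(g, x)` and `l(x)` is made per layer by dispatch, the chain does not need to start with a graph layer.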

CarloLucibello commented 2 years ago
  1. You can access the model wrapped by WithGraph as follows:
    
    julia> g = rand_graph(10, 30)
    GNNGraph:
        num_nodes = 10
        num_edges = 30

    julia> wg = WithGraph(GCNConv(4=>4), g)
    WithGraph{GCNConv{Matrix{Float32}, Vector{Float32}, typeof(identity)}, GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}(GCNConv(4 => 4), GNNGraph(10, 30), false)

    julia> wg.model
    GCNConv(4 => 4)

CarloLucibello commented 2 years ago
  1. Maybe GNNChain should be redesigned to be a generalization of Chain. It would sequentially check whether each layer is a GNNLayer: if yes, feed it (g, x); otherwise, feed it x. We would not need to use Chain anymore. WithGraph should be deprecated, I guess.

I'm not exactly sure what behavior you are suggesting, but what you describe here seems to be what GNNChain already does:

julia> using GraphNeuralNetworks, Flux

julia> g = rand_graph(10, 30)
GNNGraph:
    num_nodes = 10
    num_edges = 30

julia> x = randn(3, 10);

julia> model = GNNChain(Dense(3 => 5), GCNConv(5 => 5), Dense(5 => 1))
GNNChain(Dense(3 => 5), GCNConv(5 => 5), Dense(5 => 1))

julia> model(g, x)
1×10 Matrix{Float64}:
 0.603018  0.610886  -0.83652  0.595207  …  -0.808252  -0.203556  -0.530964

The docstring of GNNChain is the following:

help?> GNNChain
search: GNNChain

  GNNChain(layers...)
  GNNChain(name = layer, ...)

  Collects multiple layers / functions to be called in sequence on given input
  graph and input node features.

  It allows to compose layers in a sequential fashion as Flux.Chain does,
  propagating the output of each layer to the next one. In addition, GNNChain
  handles the input graph as well, providing it as a first argument only to
  layers subtyping the GNNLayer abstract type.

  GNNChain supports indexing and slicing, m[2] or m[1:end-1], and if names are
  given, m[:name] == m[1] etc.

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));

  julia> x = randn(Float32, 2, 3);

  julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);

  julia> m(g, x)
  4×3 Matrix{Float32}:
    0.157941    0.15443     0.193471
    0.0819516   0.0503105   0.122523
    0.225933    0.267901    0.241878
   -0.0134364  -0.0120716  -0.0172505

YichengDWu commented 2 years ago

I know we can access the model in one WithGraph. I was making a case where you have a chain of WithGraph(s) mixed with other layers. I'm not sure if there is an easy way to save that Chain without the graph.

YichengDWu commented 2 years ago

You are right. GNNChain does what I suggested. I was confused by the fact that a GNNChain could not be inside a Chain, but with WithGraph we can use Chain again.

CarloLucibello commented 2 years ago

I was making a case where you have a chain of WithGraph(s) mixed with other layers. I'm not sure if there is an easy way to save that Chain without the graph

Maybe you want to use

model = WithGraph(GNNChain(...), g) 

instead, and then save only the GNNChain.
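
A minimal sketch of that pattern, assuming that `WithGraph` forwards its input to the wrapped model together with the stored graph, and reusing the `model` field shown earlier in this thread. The graph sizes and layer dimensions are arbitrary, and the training and serialization steps are left out:

```julia
using GraphNeuralNetworks, Flux

g1 = rand_graph(10, 30)   # graph used for training
g2 = rand_graph(20, 60)   # a different graph, e.g. for transfer learning

# One GNNChain holds all learnable parameters; plain layers and graph layers can be mixed.
chain = GNNChain(Dense(3 => 5), GCNConv(5 => 5), Dense(5 => 1))

# Wrap the whole chain once, so the graph is stored outside the parameters.
wg1 = WithGraph(chain, g1)
y1 = wg1(rand(Float32, 3, 10))

# ... train wg1 here ...

# The graph-free part is wg1.model; this is the object one would save.
trained = wg1.model

# Reuse the same parameters on another graph.
wg2 = WithGraph(trained, g2)
y2 = wg2(rand(Float32, 3, 20))
```

With this layout the stored graph sits in a single outer wrapper, so the object to serialize is just `wg1.model`.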

CarloLucibello commented 2 years ago

Closing this issue as there doesn't seem to be anything actionable, but we can reopen it if you think anything is unresolved.

YichengDWu commented 2 years ago

Thanks for the explanation! It looks like it is just plain wrong to use Chain and WithGraph together. I misunderstood the motivation of WithGraph and was also kinda misled by the fact that all examples of GNNChain start with a GNNLayer. It would be nice if the documentation included a reminder that users should almost always use GNNChain over Chain.