
Graph classification: multiple graphs associated with a common label #282

Closed: msainsburydale closed this issue 2 months ago

msainsburydale commented 1 year ago

Hello, many thanks for a great package.

My task can be summarised as graph classification, where (i) multiple graphs are associated with a common label, and where (ii) all graphs have the same structure (i.e., only the node data changes between graphs).

For example, a single input instance may be:

using GraphNeuralNetworks
d = 1                     # dimension of response variable
n = 100                   # number of nodes in the graph
e = 200                   # number of edges in the graph
m = 30                    # number of replicates of the graph
g = rand_graph(n, e)      # fixed structure for all graphs

# Multiple graphs: these would all be associated with a 
# single label that we would like to predict
graphs = [GNNGraph(g; ndata = rand(d, n)) for _ in 1:m]

The usual approach is to batch the graphs into a single super graph:

using Flux: batch
batch(graphs)
# GNNGraph:
#   num_nodes: 3000
#   num_edges: 6000
#   num_graphs: 30
#   ndata:
#   x = 1×3000 Matrix{Float64}

This approach is natural when the graph structure varies between graphs. However, it is inefficient when the structure is fixed: the batched supergraph duplicates the same edge list m times (costing memory), and presumably the propagation and readout modules also repeat work on structure that could be shared.
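To make the memory cost concrete, here is a rough check one could run on the example above (a sketch only; Base.summarysize figures are illustrative and will vary):

# The supergraph stores m = 30 copies of the same edge list,
# while the original graph stores the shared structure once.
Base.summarysize(batch(graphs))   # structure and features for all 30 replicates
Base.summarysize(g)               # the shared structure alone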

It would be more efficient to use a single graph with ndata storing the replicated data in, for example, the third dimension:

GNNGraph(g; ndata = rand(d, n, m))

This gives an error, since ndata seems to accept matrices only, or at least arrays whose last dimension equals the number of nodes.

Do you have any suggestions for how best to proceed, in a way that aligns with the philosophy of the package?

CarloLucibello commented 1 year ago

This is an interesting application. ndata wants the last dimension to equal the number of nodes, but support is not limited to matrices, so you can add the node features as

g.ndata.x = rand(Float32, d, m, n)

I'm not sure, though, that graph convolutions will handle these tensors properly. Ideally they should, so let me know how it goes.
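For what it's worth, the lower-level propagate machinery gathers and scatters along the last (node) dimension, so a quick sanity check on a 3d array could look like this (a sketch, untested):

using GraphNeuralNetworks

g = rand_graph(100, 200)
x = rand(Float32, 2, 30, 100)          # features × replicates × nodes
m = propagate(copy_xj, g, +; xj = x)   # sum each node's neighbour features
size(m)                                # expected (2, 30, 100)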

msainsburydale commented 1 year ago

Thanks, that's very helpful.

The approach doesn't work directly "out of the box", at least for GraphConv, but it was not too difficult to get it working: GraphConv expects the node data to be an AbstractMatrix, so the existing method fails on a three-dimensional array. I therefore defined a method for three-dimensional arrays as follows, making use of batched_mul.

using GraphNeuralNetworks
using GraphNeuralNetworks: check_num_nodes  # not exported
using Flux: batched_mul, ⊠                  # ⊠ is infix batched_mul

function (l::GraphConv)(g::GNNGraph, x::AbstractArray{T, 3}) where {T}
    check_num_nodes(g, x)
    m = propagate(copy_xj, g, l.aggr; xj = x)
    # ⊠ (batched_mul) applies the weights to each node's feature × replicate slice
    l.σ.(l.weight1 ⊠ x .+ l.weight2 ⊠ m .+ l.bias)
end

I did some testing with the same simple example discussed above, focusing mainly on the dimensions of the output as a sanity check.

d = 2                      # dimension of response variable
n = 100                    # number of nodes in the graph
e = 200                    # number of edges in the graph
m = 30                     # number of replicates of the graph
g = rand_graph(n, e)       # fixed structure for all graphs
g.ndata.x = rand(Float32, d, m, n)  # node data varies between graphs

# One layer only:
out = 16
l = GraphConv(d => out)
l(g)                       # returns a new GNNGraph with updated node features
size(l(g).ndata.x)         # (16, 30, 100)

# Propagation and global-pooling modules:
gnn = GNNChain(
    GraphConv(d => out),
    GraphConv(out => out),
    GraphConv(out => out),
    GlobalPool(+)
)
gnn(g)
u = gnn(g).gdata.u
size(u)    # (16, 30, 1) 

The pooled features are a three-dimensional array of size out × m × 1, which is very close to the format of the pooled features one would obtain when "batching" the graph replicates into a single supergraph (in that case, the pooled features are a matrix of size out × m). But I suppose that Flux.flatten, or a similar reshape, can be added to the full network architecture to deal with this inconsistency.
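As a quick shape check (a sketch; note that Flux.flatten merges all leading dimensions together, so dropping the trailing singleton dimension seems closer to the batched format):

import Flux

u = gnn(g).gdata.u             # (16, 30, 1)
size(Flux.flatten(u))          # (480, 1): leading dimensions merged
size(dropdims(u, dims = 3))    # (16, 30): matches the batched supergraph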

Thanks again for your help.

CarloLucibello commented 1 year ago

Yes, something like

gnn = GNNChain(
    GraphConv(d => out),
    GraphConv(out => out),
    GraphConv(out => out),
    GlobalPool(+),
    x -> reshape(x, size(x)[1:end-1])
)

should work.
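Calling the chain with explicit node features should then return the reshaped array directly, e.g. (a sketch, assuming the definitions above):

u = gnn(g, g.ndata.x)
size(u)   # (16, 30), the same layout as pooling a batched supergraph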