Closed msainsburydale closed 2 months ago
This is an interesting application. `ndata` wants the last dimension to be the same as the number of nodes, but the support is not limited to matrices, so you can add node features as

```julia
g.ndata.x = rand(Float32, d, m, n)
```

I'm not sure, though, that graph convolutions will handle these tensors properly. Ideally they should, so let me know how it goes.
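To make the convention concrete, here is a minimal, dependency-free Julia sketch of the shape rule described above: any feature array is acceptable as long as its last dimension equals the number of nodes.

```julia
# The ndata convention: the LAST dimension of a feature array must equal
# the number of nodes; extra leading dimensions are allowed.
d, m, n = 2, 30, 100               # feature dim, replicates, nodes
x_mat = rand(Float32, d, n)        # the usual matrix layout
x_3d  = rand(Float32, d, m, n)     # replicated layout, last dim = nodes
size(x_mat)[end] == n              # true
size(x_3d)[end] == n               # true
```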
Thanks, that's very helpful.
The approach doesn't work directly "out of the box", at least for `GraphConv`, but it was not too difficult to get it working. `GraphConv` expects the node data to be an `AbstractMatrix`, and the method for matrices doesn't work when the data is a three-dimensional array. So I defined a method for three-dimensional arrays as follows, making use of `batched_mul`:
```julia
using GraphNeuralNetworks
using NNlib: ⊠                      # ⊠ is shorthand for batched_mul
import GraphNeuralNetworks: GraphConv, check_num_nodes

# Method for node features stored as a three-dimensional array (d × m × n)
function (l::GraphConv)(g::GNNGraph, x::AbstractArray{T, 3}) where {T}
    check_num_nodes(g, x)
    m = GraphNeuralNetworks.propagate(copy_xj, g, l.aggr, xj = x)
    l.σ.(l.weight1 ⊠ x .+ l.weight2 ⊠ m .+ l.bias)
end
```
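For intuition about what `⊠` (`batched_mul`) computes in the method above, here is a dependency-free reference sketch: when the left operand is an ordinary matrix and the right operand is a three-dimensional array, the same matrix is applied to every slice along the last dimension. The helper name `batched_matmul_ref` is made up for illustration; it is not part of NNlib.

```julia
# Reference implementation (illustration only) of W ⊠ x for
# W :: (out × d) matrix and x :: (d × m × n) array: apply W to each
# slice x[:, :, k], giving an (out × m × n) result.
function batched_matmul_ref(W::AbstractMatrix, x::AbstractArray{<:Real, 3})
    out, d = size(W)
    @assert d == size(x, 1) "inner dimensions must match"
    y = similar(x, out, size(x, 2), size(x, 3))
    for k in axes(x, 3)
        y[:, :, k] = W * x[:, :, k]
    end
    return y
end

W = rand(Float32, 16, 2)         # like l.weight1 with out = 16, d = 2
x = rand(Float32, 2, 30, 100)    # node features: d × m × n
size(batched_matmul_ref(W, x))   # (16, 30, 100)
```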
I did some testing with the same simple example discussed above, focusing mainly on the dimensions of the output as a sanity check.
```julia
d = 2    # dimension of response variable
n = 100  # number of nodes in the graph
e = 200  # number of edges in the graph
m = 30   # number of replicates of the graph

g = rand_graph(n, e)                 # fixed structure for all graphs
g.ndata.x = rand(Float32, d, m, n)   # node data varies between graphs

# One layer only:
out = 16
l = GraphConv(d => out)
size(l(g))  # (16, 30, 100)
```
```julia
# Propagation and global-pooling modules:
gnn = GNNChain(
    GraphConv(d => out),
    GraphConv(out => out),
    GraphConv(out => out),
    GlobalPool(+)
)

u = gnn(g).gdata.u
size(u)  # (16, 30, 1)
```
The pooled features are a three-dimensional array of size `out × m × 1`, which is very close to the format of the pooled features one would obtain when "batching" the graph replicates into a single supergraph (in that case, the pooled features are a matrix of size `out × m`). But I suppose that `Flux.flatten` can be added to the full network architecture to deal with this inconsistency.
Thanks again for your help.
Yes, something like

```julia
gnn = GNNChain(
    GraphConv(d => out),
    GraphConv(out => out),
    GraphConv(out => out),
    GlobalPool(+),
    x -> reshape(x, size(x)[1:end-1])
)
```

should work.
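The effect of that final `reshape` can be checked in plain Julia without the package: it simply drops the trailing singleton dimension left by the pooling layer, turning the `out × m × 1` array into the `out × m` matrix one would get from the batched-supergraph approach.

```julia
# Drop the trailing singleton dimension left by GlobalPool:
out, m = 16, 30
x = rand(Float32, out, m, 1)        # pooled features, out × m × 1
y = reshape(x, size(x)[1:end-1])    # out × m matrix
size(y)                             # (16, 30)
```

Note that `reshape` returns a view sharing memory with `x`; no data is copied.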
Hello, many thanks for a great package.
My task can be summarised as graph classification, where (i) multiple graphs are associated with a common label, and where (ii) all graphs have the same structure (i.e., only the node data changes between graphs).
For example, a single input instance may be:
The usual approach is to batch the graphs into a single super graph:
This approach is natural when the graph structure varies between graphs. However, it is inefficient when the graphs have a fixed structure (particularly with respect to memory, but presumably also in terms of performing the required operations during the propagation and readout modules).
It would be more efficient to use a single graph with `ndata` storing the replicated data in, for example, the third dimension:

This gives an error, since `ndata` should contain ~~matrices only, I think~~ arrays with the last dimension equal to the number of nodes.

Do you have any suggestions for how best to proceed, in a way that aligns with the philosophy of the package?