CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

don't automatically batch when `getobs` from an array of graphs #170

Closed CarloLucibello closed 2 years ago

CarloLucibello commented 2 years ago

We currently have the following definitions

# DataLoader compatibility: passing a vector of graphs and
# effectively using `batch` as a collate function.
MLUtils.numobs(data::Vector{<:GNNGraph}) = length(data)
MLUtils.getobs(data::Vector{<:GNNGraph}, i::Int) = data[i]
MLUtils.getobs(data::Vector{<:GNNGraph}, i) = Flux.batch(data[i])

These make for a nice interaction with MLUtils.DataLoader, since we get automatic batching (a.k.a. collating) for free.
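To illustrate the current behavior, here is a minimal sketch (assuming GraphNeuralNetworks.jl and MLUtils are installed) of how the definitions above make DataLoader yield batched graphs:

```julia
using MLUtils, GraphNeuralNetworks

graphs = [rand_graph(10, 20) for _ in 1:5]

# With the `getobs` definitions above, each iteration of the loader
# calls `getobs(graphs, idxs)`, which collates the selected graphs
# into a single GNNGraph via `Flux.batch`.
loader = MLUtils.DataLoader(graphs, batchsize=2, shuffle=false)
for g in loader
    # `g` is one GNNGraph containing up to 2 sub-graphs
    @show g.num_graphs
end
```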

On the other hand, this doesn't play well with other transformations in MLUtils, where one ends up with a batched graph even when that wasn't intended:

julia> using MLUtils, GraphNeuralNetworks

julia> graphs = [rand_graph(10, 20)  for i=1:5]
5-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
 GNNGraph(10, 20)
 GNNGraph(10, 20)
 GNNGraph(10, 20)
 GNNGraph(10, 20)
 GNNGraph(10, 20)

julia> shuffleobs(graphs) |> getobs
GNNGraph:
    num_nodes = 50
    num_edges = 100
    num_graphs = 5

We should remove the automatic batching from here; instead, the DataLoader itself should provide automatic batching by calling MLUtils.batch.
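A sketch of the proposed change (an assumption about the eventual design, not the merged implementation): `getobs` on an index collection would simply return a sub-vector of graphs, and batching would be requested explicitly at the DataLoader level, e.g. through a `collate`-style option:

```julia
using MLUtils, GraphNeuralNetworks
import Flux

# Proposed: no implicit batching in `getobs`.
# MLUtils.numobs(data::Vector{<:GNNGraph}) = length(data)
# MLUtils.getobs(data::Vector{<:GNNGraph}, i) = data[i]   # returns graphs, not a batch

graphs = [rand_graph(10, 20) for _ in 1:5]

# Batching then becomes an explicit, opt-in choice at the loader,
# here via a hypothetical collate function argument:
loader = MLUtils.DataLoader(graphs; batchsize=2, collate=Flux.batch)
```

Under this design, `shuffleobs(graphs) |> getobs` would return a shuffled `Vector{GNNGraph}` rather than a single batched graph.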

CarloLucibello commented 2 years ago

Let's wait for the outcome of the discussion in https://github.com/JuliaML/MLUtils.jl/issues/90, and for whatever DataLoader Flux ends up exporting to have a collate option, before solving this issue.