CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

The constraint in Flux.batch(gs::AbstractVector{<:GNNHeteroGraph}) does not seem to be strong enough #341

Closed: AarSeBail closed this issue 10 months ago.

AarSeBail commented 10 months ago

While I'd also prefer to have no constraint at all, I don't believe the current constraint is sufficient: it does not rule out batches whose graphs have different edge types.

https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/066d8f50c88f4e558be50c9f45b7fbdea8bb6527/src/GNNGraphs/transform.jl#L585C3-L585C3

Consider, for example, the following:

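```julia
using GraphNeuralNetworks
using Flux

# Both graphs have node types :A and :B, but only the first one
# has edges of type (:A, :to2, :B).
gs = [rand_heterograph((:A => 10, :B => 14), ((:A, :to1, :A) => 5, (:A, :to2, :B) => 20)),
      rand_heterograph((:A => 10, :B => 14), ((:A, :to1, :A) => 5,))]

g = Flux.batch(gs)  # throws a KeyError
```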

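The call fails with `ERROR: KeyError: key (:A, :to2, :B) not found`, raised by this line:

```julia
num_edges = Dict(edge_t => sum(g.num_edges[edge_t] for g in gs) for edge_t in etypes)
```

The second graph has no `(:A, :to2, :B)` edges, so its `num_edges` dictionary has no entry for that key, yet the current assertion lets the batch through.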
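To make the failure mechanism concrete without depending on the package, here is a minimal sketch that uses plain `Dict`s as hypothetical stand-ins for each graph's `num_edges` field. `strict_counts` mirrors the line quoted above; `tolerant_counts` is a hypothetical variant (an assumption for illustration, not the package's API) showing one way the count step could work without the constraint, by treating a missing edge type as zero edges via `get`.

```julia
# Hypothetical stand-in for the edge-type key used by heterographs.
const EdgeType = Tuple{Symbol,Symbol,Symbol}

# Edge counts per edge type for the two graphs in the reproducer.
num_edges_1 = Dict{EdgeType,Int}((:A, :to1, :A) => 5, (:A, :to2, :B) => 20)
num_edges_2 = Dict{EdgeType,Int}((:A, :to1, :A) => 5)
all_num_edges = [num_edges_1, num_edges_2]

# Union of edge types over all graphs, as a batch would need.
etypes = union([collect(keys(d)) for d in all_num_edges]...)

# Mirrors the failing line: indexes every graph with every edge type,
# so it throws a KeyError for the type the second graph lacks.
function strict_counts(all_num_edges, etypes)
    Dict(edge_t => sum(d[edge_t] for d in all_num_edges) for edge_t in etypes)
end

# Hypothetical tolerant variant: a missing edge type counts as zero edges.
function tolerant_counts(all_num_edges, etypes)
    Dict(edge_t => sum(get(d, edge_t, 0) for d in all_num_edges) for edge_t in etypes)
end

# Capture the exception instead of letting it propagate.
caught = try
    strict_counts(all_num_edges, etypes)
    nothing
catch err
    err
end

println(caught)
println(tolerant_counts(all_num_edges, etypes))
```

Note that dropping the constraint entirely would also require the edge indices and features to be concatenated per type with the same "missing means empty" rule; this sketch only covers the edge counts.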

I believe the assertions need to be something along the following lines, so that every graph is required to contain every node and edge type:

```julia
# TODO remove these constraints
@assert ntypes == intersect([g.ntypes for g in gs]...)
@assert etypes == intersect([g.etypes for g in gs]...)
```
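The sketch below (plain Julia, no package needed) illustrates why requiring the union of types to equal their intersection closes the gap. The vectors are hypothetical stand-ins for `[g.etypes for g in gs]` from the reproducer, and `etypes` is assumed to be the union over all graphs. For contrast, it also shows a check against the first graph alone (an assumed weaker form, not necessarily the package's exact code), which passes here because the first graph happens to contain every edge type.

```julia
# Hypothetical stand-ins for [g.etypes for g in gs] from the reproducer.
etype_lists = [
    [(:A, :to1, :A), (:A, :to2, :B)],  # first graph: both edge types
    [(:A, :to1, :A)],                  # second graph: (:A, :to2, :B) missing
]

# Every edge type seen in any graph.
etypes = union(etype_lists...)

# Weaker check (assumed for contrast): compare only against the first graph.
weak_ok = etypes == etype_lists[1]

# Proposed check: the union must equal the intersection, i.e. no graph
# is missing an edge type that another graph has.
strong_ok = etypes == intersect(etype_lists...)

# A consistent batch passes the proposed check, even with a different order.
consistent = [[(:A, :to1, :A), (:A, :to2, :B)], [(:A, :to2, :B), (:A, :to1, :A)]]
consistent_ok = union(consistent...) == intersect(consistent...)

println((weak_ok, strong_ok, consistent_ok))
```

This prints `(true, false, true)`. Because both `union` and `intersect` keep the order of their first argument, the vector comparison still succeeds when consistent graphs list their types in different orders.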