CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

`add_edges` for GNNHeteroGraph does not allow providing the number of nodes #334

Closed: svilupp closed this issue 10 months ago

svilupp commented 10 months ago

I think I've found a bug when adding new edges that involve new node types to a GNNHeteroGraph.

The documentation for add_edges says:

If the edge type is not already present in the graph, it is added. If it involves new node types, they are added to the graph as well. In this case, a dictionary or named tuple of num_nodes can be passed to specify the number of nodes of the new types, otherwise the number of nodes is inferred from the maximum
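To make the documented rule concrete, here is a minimal sketch in plain Julia. It assumes that, for each node type of the new edge type, the node count should be resolved in this order: the count already stored in the graph, then the count passed via num_nodes (Dict or NamedTuple), and otherwise the maximum node index in the new edges. The function resolve_num_nodes is hypothetical: it is not part of GraphNeuralNetworks.jl and does not reproduce the package's internals or the eventual fix. The check that a user-supplied count covers the largest index is also an assumption, not something the documentation states.

```julia
# Hypothetical helper, NOT part of GraphNeuralNetworks.jl.
# Resolves the node count for one node type when adding edges:
#   existing  : node counts already stored in the graph (node type => count)
#   node_t    : the node type to resolve (e.g. :actor)
#   idx       : node indices of this type appearing in the new edges
#   num_nodes : optional user-supplied counts (Dict, NamedTuple, or nothing)
function resolve_num_nodes(existing::Dict{Symbol,Int}, node_t::Symbol,
                           idx::AbstractVector{<:Integer}, num_nodes)
    # 1. Node type already in the graph: keep its stored count.
    haskey(existing, node_t) && return existing[node_t]
    # 2. New node type with a user-provided count: use it
    #    (haskey and getindex work for both Dict and NamedTuple).
    if num_nodes !== nothing && haskey(num_nodes, node_t)
        n = num_nodes[node_t]
        n >= maximum(idx) || throw(ArgumentError(
            "num_nodes[$node_t] = $n is smaller than the largest index $(maximum(idx))"))
        return n
    end
    # 3. Otherwise infer the count from the maximum index in the new edges.
    return maximum(idx)
end

# Usage, mirroring the MWE below (node counts only, no graph object):
existing = Dict(:user => 10, :movie => 15)
snew, tnew = [1, 2, 3, 3, 3], [3, 5, 1, 9, 4]

resolve_num_nodes(existing, :actor, snew, Dict(:actor => 10))  # 10
resolve_num_nodes(existing, :actor, snew, (actor = 10,))       # 10
resolve_num_nodes(existing, :actor, snew, nothing)             # 3 (inferred)
resolve_num_nodes(existing, :movie, tnew, Dict(:actor => 10))  # 15 (already known)
```

The key point is that a new node type is looked up in the user-supplied num_nodes before falling back to inference, rather than being read directly from the graph's existing counts, which is where the KeyError for :actor would come from.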

Expected behavior: when adding edges that involve a new node type, I should be able to specify the number of nodes of that type via the num_nodes keyword argument.

Actual behavior: when I pass num_nodes, I get:

ERROR: KeyError: key :actor not found
Stacktrace:
 [1] getindex
   @ GraphNeuralNetworks.GNNGraphs ./dict.jl:496 [inlined]
 [2] add_edges(g::GNNHeteroGraph{…}, edge_t::Tuple{…}, snew::Vector{…}, tnew::Vector{…}; edata::Nothing, num_nodes::Dict{…})
   @ GraphNeuralNetworks.GNNGraphs ~/Documents/GitHub/GraphNeuralNetworks.jl/src/GNNGraphs/transform.jl:195
 [3] add_edges
   @ GraphNeuralNetworks.GNNGraphs ~/Documents/GitHub/GraphNeuralNetworks.jl/src/GNNGraphs/transform.jl:162 [inlined]
 [4] #add_edges#178
   @ GraphNeuralNetworks.GNNGraphs ~/Documents/GitHub/GraphNeuralNetworks.jl/src/GNNGraphs/transform.jl:160 [inlined]
 [5] top-level scope

MWE:

using GraphNeuralNetworks

g = GNNHeteroGraph((:user, :rate, :movie) => ([1,1,2,3], [7,13,5,7]); num_nodes=Dict(:user=>10,:movie=>15))
g = add_edges(g, (:actor, :like, :movie) => ([1,2,3,3,3], [3,5,1,9,4]); num_nodes=Dict(:actor=>10))
# ERROR: KeyError: key :actor not found

# NamedTuple fails too
g = add_edges(g, (:actor, :like, :movie) => ([1,2,3,3,3], [3,5,1,9,4]); num_nodes=(actor=10,))
# ERROR: KeyError: key :actor not found

# This works (the number of :actor nodes is inferred from the maximum index):
g = add_edges(g, (:actor, :like, :movie) => ([1,2,3,3,3], [3,5,1,9,4]))

Versioninfo:

julia> versioninfo()
Julia Version 1.10.0-beta2
Commit a468aa198d0 (2023-08-17 06:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
  Threads: 8 on 6 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

Package version: 0.6.11 (current main branch)

svilupp commented 10 months ago

I'll open a PR with a fix later.

CarloLucibello commented 10 months ago

Thanks, a PR would be appreciated. Heterograph support is still quite rough and needs more love.