FluxML / GeometricFlux.jl

Geometric Deep Learning for Flux
https://fluxml.ai/GeometricFlux.jl/stable/
MIT License
348 stars 30 forks source link

GraphConv errors on graphs with isolated nodes #216

Open CarloLucibello opened 3 years ago

CarloLucibello commented 3 years ago

Exposed in #214

julia> A = [0 1 0
            1 0 0
            0 0 0];

julia> fg = FeaturedGraph(A, nf=rand(2,3));

julia> m = GraphConv(2 => 2);

julia> m(fg)
ERROR: ArgumentError: reducing over an empty collection is not allowed
Stacktrace:
  [1] _empty_reduce_error()
    @ Base ./reduce.jl:299
  [2] mapreduce_empty(f::Function, op::Function, T::Type)
    @ Base ./reduce.jl:342
  [3] reduce_empty(op::Base.MappingRF{GeometricFlux.var"#17#18"{GraphConv{NullGraph, Matrix{Float32}, Vector{Float32}}, Int64, Dict{Tuple{UInt32, UInt32}, UInt64}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Float64}}, typeof(hcat)}, #unused#::Type{Int64})
    @ Base ./reduce.jl:329
  [4] reduce_empty_iter
    @ ./reduce.jl:355 [inlined]
  [5] mapreduce_empty_iter(f::Function, op::Function, itr::Vector{Int64}, ItrEltype::Base.HasEltype)
    @ Base ./reduce.jl:351
  [6] _mapreduce(f::GeometricFlux.var"#17#18"{GraphConv{NullGraph, Matrix{Float32}, Vector{Float32}}, Int64, Dict{Tuple{UInt32, UInt32}, UInt64}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Float64}}, op::typeof(hcat), #unused#::IndexLinear, A::Vector{Int64})
    @ Base ./reduce.jl:400
  [7] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
  [8] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [9] mapreduce
    @ ./reducedim.jl:310 [inlined]
 [10] apply_batch_message
    @ ~/.julia/dev/GeometricFlux/src/layers/msgpass.jl:51 [inlined]
 [11] #11
    @ ~/.julia/dev/GeometricFlux/src/layers/gn.jl:22 [inlined]
 [12] _mapreduce(f::GeometricFlux.var"#11#12"{GraphConv{NullGraph, Matrix{Float32}, Vector{Float32}}, Vector{Vector{Int64}}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Matrix{Float64}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, Dict{Tuple{UInt32, UInt32}, UInt64}}, op::typeof(hcat), #unused#::IndexLinear, A::UnitRange{Int64})
    @ Base ./reduce.jl:411
 [13] _mapreduce_dim
    @ ./reducedim.jl:318 [inlined]
 [14] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
 [15] mapreduce
    @ ./reducedim.jl:310 [inlined]
 [16] update_batch_edge
    @ ~/.julia/dev/GeometricFlux/src/layers/gn.jl:22 [inlined]
 [17] propagate(gn::GraphConv{NullGraph, Matrix{Float32}, Vector{Float32}}, adj::Vector{Vector{Int64}}, E::FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, V::Matrix{Float64}, u::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, naggr::Function, eaggr::Nothing, vaggr::Nothing)
    @ GeometricFlux ~/.julia/dev/GeometricFlux/src/layers/gn.jl:63
 [18] propagate
    @ ~/.julia/dev/GeometricFlux/src/layers/msgpass.jl:63 [inlined]
 [19] GraphConv
    @ ~/.julia/dev/GeometricFlux/src/layers/conv.jl:162 [inlined]
 [20] (::GraphConv{NullGraph, Matrix{Float32}, Vector{Float32}})(fg::FeaturedGraph{Matrix{Int64}, Matrix{Float64}, FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}})
    @ GeometricFlux ~/.julia/dev/GeometricFlux/src/layers/conv.jl:166
 [21] top-level scope
    @ REPL[31]:1
 [22] top-level scope
    @ ~/.julia/packages/CUDA/lwSps/src/initialization.jl:52
yuehhua commented 3 years ago

A message-passing neural network relies on the connections between nodes: messages are passed along edges. Therefore, isolated nodes can be removed to avoid the error. I don't think this is a bug; it behaves this way by design.

eahenle commented 11 months ago

Removing the isolated nodes is not always correct — the readout/pooling may easily be affected by it. Nothing in the original paper or in the mathematical formulation given in the docs suggests this should fail, and the error message is very unhelpful. This limitation should at least be officially documented, and there should probably also be an explicit check equivalent to `@assert !any(isequal(0), nv.(graphs))`.