FluxML / GeometricFlux.jl

Geometric Deep Learning for Flux
https://fluxml.ai/GeometricFlux.jl/stable/
MIT License

`DimensionMismatch` error using `GraphConv` layer on directed graphs #309

Open · BatyLeo opened this issue 2 years ago

BatyLeo commented 2 years ago

Hello, I'm new to GeometricFlux and currently experimenting with its features.

I just ran into the following issue: when I use a directed graph as input to a `GraphConv` layer, it raises a `DimensionMismatch` error that I don't understand. Here is a minimal working example:

```julia
using Graphs
using GeometricFlux

nb_features = 5
g = path_digraph(10)
fg = FeaturedGraph(g, nf=randn(nb_features, nv(g)))
gc = GraphConv(nb_features => 20)
gc(fg)
```

Error log:

```
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 10 and 9")
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:516 [inlined]
  [2] _bcs (repeats 2 times)
    @ ./broadcast.jl:510 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:504 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:499 [inlined]
  [5] _axes
    @ ./broadcast.jl:224 [inlined]
  [6] axes
    @ ./broadcast.jl:222 [inlined]
  [7] combine_axes
    @ ./broadcast.jl:499 [inlined]
  [8] _axes
    @ ./broadcast.jl:224 [inlined]
  [9] axes
    @ ./broadcast.jl:222 [inlined]
 [10] combine_axes(A::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(+), Tuple{Matrix{Float64}, Matrix{Float64}}}, Vector{Float32}}})
    @ Base.Broadcast ./broadcast.jl:500
 [11] instantiate
    @ ./broadcast.jl:281 [inlined]
 [12] materialize
    @ ./broadcast.jl:860 [inlined]
 [13] update(gc::GraphConv{Matrix{Float32}, Vector{Float32}, typeof(identity), typeof(+)}, m::Matrix{Float64}, x::Matrix{Float64})
    @ GeometricFlux ~/.julia/packages/GeometricFlux/eLaIW/src/layers/conv.jl:204
 [14] update_vertex
    @ ~/.julia/packages/GeometricFlux/eLaIW/src/layers/msgpass.jl:63 [inlined]
 [15] update_batch_vertex
    @ ~/.julia/packages/GeometricFlux/eLaIW/src/layers/gn.jl:16 [inlined]
 [16] propagate
    @ ~/.julia/packages/GeometricFlux/eLaIW/src/layers/gn.jl:52 [inlined]
 [17] propagate(gn::GraphConv{Matrix{Float32}, Vector{Float32}, typeof(identity), typeof(+)}, sg::SparseGraph{true, SparseArrays.SparseMatrixCSC{Float32, UInt32}, Vector{UInt32}, Int64}, E::Nothing, V::Matrix{Float64}, u::Nothing, naggr::Function, eaggr::Nothing, vaggr::Nothing)
    @ GeometricFlux ~/.julia/packages/GeometricFlux/eLaIW/src/layers/gn.jl:38
 [18] (::GraphConv{Matrix{Float32}, Vector{Float32}, typeof(identity), typeof(+)})(fg::FeaturedGraph{SparseGraph{true, SparseArrays.SparseMatrixCSC{Float32, UInt32}, Vector{UInt32}, Int64}, Matrix{Float64}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}})
    @ GeometricFlux ~/.julia/packages/GeometricFlux/eLaIW/src/layers/conv.jl:210
```
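
My guess (not based on reading the GeometricFlux internals) is that the "10 and 9" comes from the directed path having an endpoint with no in-neighbors, so the per-vertex neighbor aggregation produces one column fewer than the node-feature matrix. A quick check with plain Graphs.jl:

```julia
using Graphs

dg = path_digraph(10)   # directed path: edges 1->2, 2->3, ..., 9->10
ug = path_graph(10)     # undirected path on the same 10 vertices

# How many vertices have no in-neighbors at all?
count(v -> isempty(inneighbors(dg, v)), vertices(dg))  # 1 (vertex 1 receives no edges)
count(v -> isempty(inneighbors(ug, v)), vertices(ug))  # 0 (every vertex has at least one neighbor)
```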

This code works fine if I replace `path_digraph` with `path_graph`, or if I use `GCNConv` instead of `GraphConv`.
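
Since the undirected path works, a possible workaround sketch for now is to drop the edge directions with `SimpleGraph` before building the `FeaturedGraph` (this obviously loses the directionality information):

```julia
using Graphs
using GeometricFlux

nb_features = 5
g = SimpleGraph(path_digraph(10))  # undirected copy of the directed path

fg = FeaturedGraph(g, nf=randn(nb_features, nv(g)))
gc = GraphConv(nb_features => 20)
gc(fg)  # should behave like the path_graph case above
```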

yuehhua commented 2 years ago

Yeah, it seems message-passing networks (including `GraphConv`) are currently not supported over directed graphs.