I just found what I believe to be a bug in the GATConv layer. The bug appeared in SeaPearl.jl, so I don't have a minimal reproducible example yet, but I'll work on one early next week.
Description
The bug appears in the following context:
Using a Flux.Chain containing at least two GATConv layers;
Loading the chain onto a GPU with CUDA;
Running a forward pass through the chain.
When running the code, I receive the following error:
ERROR: LoadError: MethodError: update_batch_edge(::GATConv{NullGraph, Float32}, ::Vector{Vector{Int64}}, ::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}) is ambiguous. Candidates:
update_batch_edge(g::GATConv, adj, E::AbstractMatrix{T} where T, X::AbstractMatrix{T} where T, u) in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/layers/conv.jl:308
update_batch_edge(mp::T, adj, E::CuArray{T, 2} where T, X::CuArray{T, 2} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:21
update_batch_edge(mp::T, adj, E::AbstractMatrix{T} where T, X::CuArray{T, 2} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:11
update_batch_edge(mp::T, adj, E::CuArray{T, 2} where T, X::AbstractMatrix{T} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:16
Possible fix, define
update_batch_edge(::T, ::Any, ::CuArray{T, 2} where T, ::CuArray{T, 2} where T, ::Any) where T<:GATConv
Stacktrace:
[1] propagate(gn::GATConv{NullGraph, Float32}, adj::Vector{Vector{Int64}}, E::CuArray{Float32, 2}, V::CuArray{Float32, 2}, u::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, naggr::Function, eaggr::Nothing, vaggr::Nothing)
@ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/gn.jl:65
[2] propagate(mp::GATConv{NullGraph, Float32}, adj::Vector{Vector{Int64}}, E::CuArray{Float32, 2}, X::CuArray{Float32, 2}, aggr::Function)
@ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/msgpass.jl:57
[3] propagate(mp::GATConv{NullGraph, Float32}, fg::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}, aggr::Function)
@ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/msgpass.jl:52
[4] (::GATConv{NullGraph, Float32})(fg::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}})
@ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/conv.jl:345
[5] applychain(fs::Tuple{GATConv{NullGraph, Float32}}, x::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}) (repeats 2 times)
@ Flux ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36
[6] (::Chain{Tuple{GATConv{NullGraph, Float32}, GATConv{NullGraph, Float32}}})(x::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}})
@ Flux ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:38
(Call stack truncated because it contains a lot of SeaPearl-related code)
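To make the dispatch problem easier to see, here is a minimal, self-contained sketch of the same kind of method ambiguity, without CUDA or GeometricFlux. Every name is a hypothetical stand-in: `ToyGAT` plays the role of `GATConv`, `DeviceMatrix` plays `CuArray`, and `update_edge` is a simplified `update_batch_edge` with the `adj` and `u` arguments dropped. It also assumes that the existing layer-specific patch only covers the "host E, device X" case of the first pass; the final method is the kind of definition Julia's "Possible fix" message suggests.

```julia
# Hypothetical stand-ins, not GeometricFlux types:
#   MessagePassing / ToyGAT ~ GeometricFlux.MessagePassing / GATConv
#   DeviceMatrix            ~ CuArray{T,2}
abstract type MessagePassing end
struct ToyGAT <: MessagePassing end

struct DeviceMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::DeviceMatrix) = size(A.data)
Base.getindex(A::DeviceMatrix, i::Int, j::Int) = A.data[i, j]

# Layer-specific method, generic in the array types (plays the role of conv.jl:308).
update_edge(g::ToyGAT, E::AbstractMatrix, X::AbstractMatrix) = :layer_generic

# Device methods, generic in the layer type (play the role of cuda/msgpass.jl:11, :16, :21).
update_edge(mp::MessagePassing, E::AbstractMatrix, X::DeviceMatrix) = :device_X
update_edge(mp::MessagePassing, E::DeviceMatrix, X::AbstractMatrix) = :device_E
update_edge(mp::MessagePassing, E::DeviceMatrix, X::DeviceMatrix) = :device_both

# Assumed layer-specific patch that only covers the first pass (host E, device X).
update_edge(g::ToyGAT, E::Matrix, X::DeviceMatrix) = :first_pass_patch

X = DeviceMatrix(rand(Float32, 2, 3))

# First pass: E is not a device array yet, and the patch is the most specific method.
first_pass = update_edge(ToyGAT(), rand(Float32, 2, 2), X)

# Second pass: E has become a device array, so no applicable method is
# more specific than all the others and Julia throws an ambiguity MethodError.
second_pass_error = try
    update_edge(ToyGAT(), DeviceMatrix(rand(Float32, 2, 2)), X)
    nothing
catch err
    err
end
println(second_pass_error isa MethodError)

# Resolving the ambiguity: a method specific to both the layer and the device array type.
update_edge(g::ToyGAT, E::DeviceMatrix, X::DeviceMatrix) = :layer_and_device
second_pass = update_edge(ToyGAT(), DeviceMatrix(rand(Float32, 2, 2)), X)
```

The ambiguity arises because the layer-specific method is more specific in the first argument while the device methods are more specific in the array arguments, so neither wins once both E and X are device arrays.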
Explanation
After tracking down the bug, I found out what is happening:
During the first pass through a GATConv, everything works fine, but after the computation, in
propagate(gn, adj, E, V, u, naggr, eaggr, vaggr) @ GeometricFlux /GeometricFlux.jl/src/layers/gn.jl:65, the ef field of the FeaturedGraph is filled with E::CuArray, even though E is only a side product of computing the vertex features.
As a result, the FeaturedGraph now holds a CuArray instead of a FillArray in its ef field.
During the second pass, update_batch_edge is called with two CuArrays, a combination that isn't caught in the file cuda/conv.jl, which only has
Potential Fixes
The potential fixes I see are either:
[…] the ef field for GATConv => most logical, but I'm not entirely sure how to do it;
What are your thoughts about this?
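Since E is only a side product for the vertex features, one option consistent with the explanation is to leave ef untouched and return a graph that carries the original edge features. The sketch below is a hypothetical illustration of that idea, not GeometricFlux code: `Graph` stands in for `FeaturedGraph` (only `ef` and `nf` kept), `NoEdgeFeatures` for the FillArrays.Fill placeholder, and `toy_layer` for a GATConv forward pass. The E and V computations are placeholders, not the GAT attention formula.

```julia
# Hypothetical stand-ins: `Graph` mimics a FeaturedGraph reduced to its ef/nf
# fields, and `NoEdgeFeatures` mimics the FillArrays.Fill placeholder for ef.
struct NoEdgeFeatures end

struct Graph{E,V}
    ef::E
    nf::V
end

# Toy layer: E is an intermediate used only to build the new node features.
# `overwrite_ef=true` mimics the reported behaviour (ef replaced by E);
# `overwrite_ef=false` keeps the incoming ef unchanged.
function toy_layer(g::Graph; overwrite_ef::Bool)
    E = g.nf' * g.nf   # nodes x nodes scores (side product, placeholder math)
    V = g.nf * E       # features x nodes output
    return overwrite_ef ? Graph(E, V) : Graph(g.ef, V)
end

g0 = Graph(NoEdgeFeatures(), rand(Float32, 3, 4))

reported = toy_layer(g0; overwrite_ef=true)   # ef becomes a dense matrix
proposed = toy_layer(g0; overwrite_ef=false)  # ef keeps its original type

println(typeof(reported.ef))
println(typeof(proposed.ef))

# Chaining two layers: the second layer sees the same ef type as the first.
twice = toy_layer(toy_layer(g0; overwrite_ef=false); overwrite_ef=false)
```

With `overwrite_ef=true` the second layer receives a dense matrix for ef, which is the type change that leads to the ambiguous `update_batch_edge` call; with `overwrite_ef=false` the ef type stays stable across layers.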