FluxML / GeometricFlux.jl

Geometric Deep Learning for Flux
https://fluxml.ai/GeometricFlux.jl/stable/
MIT License

Bug with CUDA + GATConv #185

Status: Closed (PierreTsr closed this issue 2 years ago)

PierreTsr commented 3 years ago

I just found what I believe to be a bug in the GATConv layer. The bug appeared in SeaPearl.jl, so I don't have a minimal reproducible example yet, but I'll work on one early next week.

Description

The bug appears in the following context:

When running the code, I receive the following error:

ERROR: LoadError: MethodError: update_batch_edge(::GATConv{NullGraph, Float32}, ::Vector{Vector{Int64}}, ::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}) is ambiguous. Candidates:
  update_batch_edge(g::GATConv, adj, E::AbstractMatrix{T} where T, X::AbstractMatrix{T} where T, u) in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/layers/conv.jl:308
  update_batch_edge(mp::T, adj, E::CuArray{T, 2} where T, X::CuArray{T, 2} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:21
  update_batch_edge(mp::T, adj, E::AbstractMatrix{T} where T, X::CuArray{T, 2} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:11
  update_batch_edge(mp::T, adj, E::CuArray{T, 2} where T, X::AbstractMatrix{T} where T, u) where T<:MessagePassing in GeometricFlux at /home/pierre/Documents/Stage/GeometricFlux.jl/src/cuda/msgpass.jl:16
Possible fix, define
  update_batch_edge(::T, ::Any, ::CuArray{T, 2} where T, ::CuArray{T, 2} where T, ::Any) where T<:GATConv
Stacktrace:
  [1] propagate(gn::GATConv{NullGraph, Float32}, adj::Vector{Vector{Int64}}, E::CuArray{Float32, 2}, V::CuArray{Float32, 2}, u::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, naggr::Function, eaggr::Nothing, vaggr::Nothing)
    @ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/gn.jl:65
  [2] propagate(mp::GATConv{NullGraph, Float32}, adj::Vector{Vector{Int64}}, E::CuArray{Float32, 2}, X::CuArray{Float32, 2}, aggr::Function)
    @ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/msgpass.jl:57
  [3] propagate(mp::GATConv{NullGraph, Float32}, fg::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}, aggr::Function)
    @ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/msgpass.jl:52
  [4] (::GATConv{NullGraph, Float32})(fg::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}})
    @ GeometricFlux ~/Documents/Stage/GeometricFlux.jl/src/layers/conv.jl:345
  [5] applychain(fs::Tuple{GATConv{NullGraph, Float32}}, x::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}) (repeats 2 times)
    @ Flux ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:36
  [6] (::Chain{Tuple{GATConv{NullGraph, Float32}, GATConv{NullGraph, Float32}}})(x::FeaturedGraph{CuArray{Float32, 2}, CuArray{Float32, 2}, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}})
    @ Flux ~/.julia/packages/Flux/goUGu/src/layers/basic.jl:38

(Call stack truncated; the remaining frames are mostly SeaPearl-related code.)

Explanation

After tracking the bug, I found out what is happening:
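The ambiguity reported in the error above can be reproduced without a GPU. This is a minimal sketch, assuming hypothetical stand-in types (`MessagePassing`, `GATConv`, and `FakeGPUMatrix` in place of `CuArray`) rather than GeometricFlux's real definitions. One method is specific in the layer type but generic in the arrays (like `src/layers/conv.jl:308`). The other is generic in the layer type but specific in the arrays (like `src/cuda/msgpass.jl:21`). A call with a `GATConv` and two GPU matrices matches both, and neither is more specific, so Julia refuses to pick one.

```julia
# Hypothetical stand-ins; these are NOT GeometricFlux's real types.
abstract type MessagePassing end
struct GATConv <: MessagePassing end

# Hypothetical stand-in for CuArray: a concrete subtype of AbstractMatrix.
struct FakeGPUMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::FakeGPUMatrix) = size(A.data)
Base.getindex(A::FakeGPUMatrix, i::Int, j::Int) = A.data[i, j]

# Mirrors conv.jl:308: specific in the layer type, generic in the arrays.
update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = :gatconv

# Mirrors cuda/msgpass.jl:21: generic in the layer type, specific in the arrays.
update_batch_edge(mp::T, adj, E::FakeGPUMatrix, X::FakeGPUMatrix, u) where {T<:MessagePassing} = :gpu_generic

# Host (CPU) arrays: only the GATConv method applies, so dispatch succeeds.
host = update_batch_edge(GATConv(), [[2], [1]], rand(Float32, 2, 3), rand(Float32, 4, 3), nothing)

# "GPU" arrays: both methods apply and neither is more specific, so a MethodError is thrown.
E = FakeGPUMatrix(rand(Float32, 2, 3))
X = FakeGPUMatrix(rand(Float32, 4, 3))
result = try
    update_batch_edge(GATConv(), [[2], [1]], E, X, nothing)
catch err
    err
end
println(typeof(result))  # prints MethodError
```

This also suggests why the stack trace fails only once edge features are `CuArray`s: with a `Fill` or host matrix for `E`, the CUDA methods do not apply.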

Potential Fixes

The potential fixes I see are either:
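Julia's error message itself proposes one fix: define a method that is specific in both the layer type and the array types. This is a minimal sketch using the same hypothetical stand-in types (not GeometricFlux's real code). It assumes the GATConv-specific method is the one that should win on GPU arrays, so the new method forwards to it with `invoke`.

```julia
# Hypothetical stand-ins; these are NOT GeometricFlux's real types.
abstract type MessagePassing end
struct GATConv <: MessagePassing end
struct OtherLayer <: MessagePassing end   # hypothetical non-GAT layer

# Hypothetical stand-in for CuArray.
struct FakeGPUMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::FakeGPUMatrix) = size(A.data)
Base.getindex(A::FakeGPUMatrix, i::Int, j::Int) = A.data[i, j]

# The two conflicting methods, as in the error message.
update_batch_edge(g::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = :gatconv
update_batch_edge(mp::T, adj, E::FakeGPUMatrix, X::FakeGPUMatrix, u) where {T<:MessagePassing} = :gpu_generic

# Disambiguating method, the shape Julia suggests: specific in BOTH the layer
# and the arrays. Assumption: the GATConv-specific behaviour should win, so we
# forward to that method explicitly via `invoke`.
update_batch_edge(g::GATConv, adj, E::FakeGPUMatrix, X::FakeGPUMatrix, u) =
    invoke(update_batch_edge, Tuple{GATConv, Any, AbstractMatrix, AbstractMatrix, Any}, g, adj, E, X, u)

E = FakeGPUMatrix(rand(Float32, 2, 3))
X = FakeGPUMatrix(rand(Float32, 4, 3))
println(update_batch_edge(GATConv(), [[2], [1]], E, X, nothing))  # prints gatconv
```

Other `MessagePassing` layers given GPU arrays still reach the generic CUDA method; only the `GATConv` + GPU combination is disambiguated. If the CUDA-specific path should win instead, the new method would forward to that signature.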

What are your thoughts on this?

yuehhua commented 2 years ago

Please check the new version here. If you have any questions, you're welcome to reopen the issue.