JuliaGraphs / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://juliagraphs.org/GraphNeuralNetworks.jl/graphneuralnetworks/
MIT License
229 stars 45 forks source link

BatchNorm fails on GPU for integer node features after GCNConv or GINConv #140

Closed casper2002casper closed 2 years ago

casper2002casper commented 2 years ago
using GraphNeuralNetworks
using Flux

"""
    test_nn(nn, x)

Display the model `nn`, then evaluate `nn(x)` on two devices in turn:

1. "CPU": the model and input exactly as passed in.
2. "GPU": after transferring the input and then the model with `Flux.gpu`.

Each evaluation is echoed with `@show` under its device label. The value of the
last (GPU) evaluation is returned.
"""
function test_nn(nn, x)
    @show nn
    local result
    # `identity` leaves the CPU pass untouched; the GPU pass moves data first,
    # then the model, mirroring a manual `x = gpu(x); nn = gpu(nn)` transfer.
    for (label, transfer) in (("CPU", identity), ("GPU", Flux.gpu))
        println(label)
        x = transfer(x)
        nn = transfer(nn)
        # Keep the expression spelled `nn(x)` so `@show` prints the same text.
        result = @show nn(x)
    end
    return result
end

# Two 6-node graphs built from the same source/target edge vectors
# (collect(1:6), collect(1:6)). They differ only in the element type of the node
# features passed via `ndata`:
#   x  -> integer features        (rand(Int, 2, 6) is a Matrix{Int})
#   x2 -> floating-point features (rand(2, 6) is a Matrix{Float64})
x = GNNGraph(collect(1:6), collect(1:6), num_nodes = 6, ndata= rand(Int, 2, 6))
x2 = GNNGraph(collect(1:6), collect(1:6), num_nodes = 6, ndata= rand(2, 6))

# With integer features (x), Dense and GraphConv followed by BatchNorm run on
# both CPU and GPU.
test_nn(GNNChain(Dense(2, 2), BatchNorm(2)), x) #Works
test_nn(GNNChain(GraphConv(2 => 2), BatchNorm(2)), x) #Works
# GCNConv and GINConv are fine when the features are floating point (x2).
test_nn(GNNChain(GCNConv(2 => 2), BatchNorm(2)), x2) #Works
test_nn(GNNChain(GINConv(identity, 0), BatchNorm(2)), x2) #Works
# Combining integer features with GCNConv/GINConv fails only in the GPU pass.
# In the reported trace, BatchNorm receives a CuArray{Float64} while its
# parameters are CuArray{Float32}. The cuDNN `batchnorm` methods require a single
# element type T, so no method matches.
# NOTE: run as a script, the first uncaught error aborts execution. The GINConv
# line below only runs if the GCNConv line above is commented out; its error
# is presumably analogous (not shown in the report).
# Workaround (per maintainer): use Float32 features, e.g.
# ndata = Float32.(rand(Int, 2, 6)).
test_nn(GNNChain(GCNConv(2 => 2), BatchNorm(2)), x) #Error
test_nn(GNNChain(GINConv(identity, 0), BatchNorm(2)), x) #Error

Error message for the failing cases (shown here for the GCNConv chain):

nn = GNNChain(GCNConv(2 => 2), BatchNorm(2))
CPU
nn(x) = GNNGraph:
    num_nodes = 6
    num_edges = 6
    ndata:
        x => (2, 6)
GPU
ERROR: LoadError: MethodError: no method matching batchnorm(::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Float32; cache=nothing, alpha=1, beta=0, eps=1.0f-5, training=false)
Closest candidates are:
  batchnorm(::CUDA.CuArray{T}, ::CUDA.CuArray{T}, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T}, ::CUDA.CuArray{T}, ::Any; cache, alpha, beta, eps, training) where T<:Union{Float32, Float64} at ~/.julia/packages/NNlibCUDA/IeeBk/src/cudnn/batchnorm.jl:23
  batchnorm(::CUDA.CuArray{T}, ::CUDA.CuArray{T}, ::Union{CUDA.CuArray{T, 4}, CUDA.CuArray{T, 5}}, ::CUDA.CuArray{T}, ::CUDA.CuArray{T}, ::Any; cache, alpha, beta, eps, training) where T<:Union{Float32, Float64} at ~/.julia/packages/NNlibCUDA/IeeBk/src/cudnn/batchnorm.jl:27
Stacktrace:
 [1] (::BatchNorm{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})(x::CUDA.CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, cache::Nothing)
   @ Flux.CUDAint ~/.julia/packages/Flux/BPPNj/src/cuda/cudnn.jl:9
 [2] (::BatchNorm{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})(x::CUDA.CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
   @ Flux.CUDAint ~/.julia/packages/Flux/BPPNj/src/cuda/cudnn.jl:6
 [3] applylayer(l::BatchNorm{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, g::GNNGraph{Tuple{CUDA.CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}, Nothing}})
   @ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/KNr8R/src/layers/basic.jl:120
 [4] applychain (repeats 2 times)
   @ ~/.julia/packages/GraphNeuralNetworks/KNr8R/src/layers/basic.jl:133 [inlined]
 [5] (::GNNChain{Tuple{GCNConv{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, typeof(identity)}, BatchNorm{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}})(g::GNNGraph{Tuple{CUDA.CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}, Nothing}})
   @ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/KNr8R/src/layers/basic.jl:140
 [6] macro expansion
   @ ./show.jl:1047 [inlined]
 [7] test_nn(nn::GNNChain{Tuple{GCNConv{Matrix{Float32}, Vector{Float32}, typeof(identity)}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, x::GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}})
   @ Main ~/Documents/testJulia/main.jl:11
 [8] top-level scope
   @ ~/Documents/testJulia/main.jl:21
in expression starting at /home/casperp/Documents/testJulia/main.jl:21
CarloLucibello commented 2 years ago

This is a Flux problem. That said, you really shouldn't use integer features anyway; just convert them to `Float32`, e.g. `Float32.(rand(Int, 2, 6))`.

casper2002casper commented 2 years ago

FluxML/Flux.jl#1897