CarloLucibello / GraphNeuralNetworks.jl

Graph Neural Networks in Julia
https://carlolucibello.github.io/GraphNeuralNetworks.jl/dev/
MIT License

`getgraph` not working on GPU #377

Closed: aurorarossi closed this issue 4 months ago

aurorarossi commented 4 months ago

As discussed in the closed PR https://github.com/CarloLucibello/GraphNeuralNetworks.jl/pull/376, I investigated why `MLUtils.unbatch` does not work on the GPU. It calls `getobs`, which in this package calls `getgraph`, and `getgraph` is what throws the error when applied to a batch of graphs on the GPU.

Here is an MWE:

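using GraphNeuralNetworks, CUDA, MLUtils, Flux

# Two random graphs (10 nodes, 10 edges each), moved to the GPU
gs = [rand_graph(10, 10) for _ in 1:2]
gs = gpu(gs)

# Batch the graphs, then try to extract the first one
g_batch = MLUtils.batch(gs)
g1 = getgraph(g_batch, 1)  # throws the error below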

and the error:

ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceVector{Bool, 1}, ::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(in), Tuple{Base.Broadcast.Extruded{CuDeviceVector{Int64, 1}, Tuple{Bool}, Tuple{Int64}}, CUDA.CuRefValue{Vector{Int64}}}}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(in), Tuple{Base.Broadcast.Extruded{CuDeviceVector{Int64, 1}, Tuple{Bool}, Tuple{Int64}}, CUDA.CuRefValue{Vector{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{Int64, 1}, Tuple{Bool}, Tuple{Int64}}, CUDA.CuRefValue{Vector{Int64}}} which is not isbits.
    .2 is of type CUDA.CuRefValue{Vector{Int64}} which is not isbits.
      .val is of type Vector{Int64} which is not isbits.

CarloLucibello commented 4 months ago

Making https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/2681e216f8d024c9ae326338cc544989238d3cc0/src/GNNGraphs/transform.jl#L827 GPU-friendly will probably require a dedicated kernel.
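
For readers following the error message: it reports a broadcast of `in` over a device vector and a wrapped `Vector{Int64}`, and the wrapped vector is the non-isbits argument. The sketch below reproduces that pattern on the CPU with plain arrays and shows one possible way to build the same mask without capturing a `Vector` in the broadcast. The names `graph_indicator`, `selected`, and `lookup` are hypothetical stand-ins, not taken from the package source, and the alternative has not been verified on a GPU here, so it may or may not remove the need for a dedicated kernel.

```julia
# CPU-only sketch with plain Arrays. All names are hypothetical stand-ins.
graph_indicator = [1, 1, 1, 2, 2, 3, 3, 3]  # graph id of each node in the batch
num_graphs = 3
selected = [1, 3]                           # ids of the graphs to extract

# A Vector is heap-allocated, so it cannot be passed to a GPU kernel by value.
isbitstype(Int64)          # true
isbitstype(Vector{Int64})  # false

# Pattern matching the error message: broadcasting `in` against a wrapped Vector.
# Fine on the CPU; on the GPU the wrapped Vector is the non-isbits argument.
mask_in = in.(graph_indicator, Ref(selected))

# Possible alternative: mark the selected graphs in a Bool lookup table,
# then gather from it using the per-node graph ids.
lookup = fill(false, num_graphs)
lookup[selected] .= true
mask_lookup = lookup[graph_indicator]

# Single-graph case: comparing against a scalar only captures an isbits value.
mask_single = graph_indicator .== 2
```

On the GPU, `lookup` and `selected` would presumably also need to be device arrays for the gather step to stay on the device; that part is an assumption and is not tested here.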

aurorarossi commented 4 months ago

I realized that this problem is already tracked in issue #161, so I am closing this one to avoid duplicates.