aurorarossi closed this issue 4 months ago.
Making the code at https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/2681e216f8d024c9ae326338cc544989238d3cc0/src/GNNGraphs/transform.jl#L827 GPU-friendly will probably require a dedicated kernel.
I realized that this problem is already tracked in issue #161, so I am closing this one to avoid duplicates.
As discussed in the closed PR https://github.com/CarloLucibello/GraphNeuralNetworks.jl/pull/376, I investigated why `MLUtils.unbatch` did not work on the GPU. It turns out that `unbatch` uses `getobs`, and `getobs` in this package uses `getgraph`, which is the function that raises the error when called on a batch of graphs on the GPU.

Here is an MWE:

And the error: