Closed by bicycle1885.
The problem is not in this repo; it comes from the `gather!` implementation in NNlib.
The view `xi` is of type
```julia
julia> typeof(xi)
SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, false}
```
for which we have
```julia
julia> xi isa AnyCuArray
true

julia> xi isa AbstractGPUArray
false

julia> xi isa AnyGPUArray
true
```
so we are not hitting the GPU specialization at https://github.com/FluxML/NNlib.jl/blob/607de4b8fec751e1079d2822ac950028bb819c1c/src/gather.jl#L112 and end up on a path that does scalar indexing.
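The dispatch behaviour described above can be reproduced without a GPU. This is a minimal sketch, assuming hypothetical stand-in types (`MockAbstractGPUArray`, `MockGPUArray`) in place of GPUArrays.jl's `AbstractGPUArray` and CUDA.jl's `CuArray`, and a hypothetical function `which_path` in place of NNlib's `gather` methods. It shows that a method constrained to the abstract GPU type is skipped for a `view`, because `SubArray` is not a subtype of that abstract type.

```julia
# Hypothetical stand-ins for GPUArrays.jl types; NOT the real definitions.
abstract type MockAbstractGPUArray{T,N} <: AbstractArray{T,N} end

# Minimal concrete "device" array that just wraps a CPU Array.
struct MockGPUArray{T,N} <: MockAbstractGPUArray{T,N}
    data::Array{T,N}
end
Base.size(a::MockGPUArray) = size(a.data)
Base.getindex(a::MockGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Hypothetical analogue of a generic fallback plus a GPU specialization.
which_path(::AbstractArray) = :generic_fallback
which_path(::MockAbstractGPUArray) = :gpu_specialization

x = MockGPUArray(rand(Float32, 3, 4))
xi = view(x, :, :)   # a SubArray whose parent is the mock GPU array

println(typeof(xi))
@show xi isa MockAbstractGPUArray   # → false: the wrapper is not a GPU array type
@show which_path(x)                 # → :gpu_specialization
@show which_path(xi)                # → :generic_fallback
```

The parent array still "lives on the device"; only the wrapper type prevents the specialized method from being selected.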
I think this can be solved by relaxing the signature in NNlib to `AnyGPUArray`. I'll try that and see what happens.
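A minimal sketch of the proposed relaxation, again using hypothetical mock types rather than the real GPUArrays.jl definitions. `MockAnyGPUArray` is an assumed, simplified analogue of `AnyGPUArray` that admits the GPU type itself or a `SubArray` of it (the real `AnyGPUArray` also covers other wrapper types). `strict_path` and `relaxed_path` are hypothetical functions standing in for the current and relaxed `gather` signatures.

```julia
# Hypothetical stand-ins for GPUArrays.jl types; NOT the real definitions.
abstract type MockAbstractGPUArray{T,N} <: AbstractArray{T,N} end

struct MockGPUArray{T,N} <: MockAbstractGPUArray{T,N}
    data::Array{T,N}
end
Base.size(a::MockGPUArray) = size(a.data)
Base.getindex(a::MockGPUArray{T,N}, I::Vararg{Int,N}) where {T,N} = a.data[I...]

# Simplified analogue of AnyGPUArray: the GPU type or a SubArray of it.
const MockAnyGPUArray{T,N} = Union{MockAbstractGPUArray{T,N},
                                   SubArray{T,N,<:MockAbstractGPUArray}}

# Strict signature: only the unwrapped type reaches the GPU path.
strict_path(::AbstractArray) = :generic_fallback
strict_path(::MockAbstractGPUArray) = :gpu_specialization

# Relaxed signature: views of GPU arrays reach the GPU path too.
relaxed_path(::AbstractArray) = :generic_fallback
relaxed_path(::MockAnyGPUArray) = :gpu_specialization

xi = view(MockGPUArray(rand(Float32, 3, 4)), :, :)
@show strict_path(xi)    # → :generic_fallback
@show relaxed_path(xi)   # → :gpu_specialization
```

Widening the signature only changes which method dispatch selects; whether the GPU implementation then handles the wrapped array correctly still has to be checked separately.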
The following snippet causes a scalar indexing error on the GPU; it reproduces on both the latest release and the master branch.
I think this is a regression, because it started happening with GraphNeuralNetworks.jl 0.6.8 (see below). I'm not sure which package actually causes the error, but I noticed it after updating GraphNeuralNetworks.jl.