FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Missing functionalities - Metal with Conv and ConvTranspose layers #2278

Open 4SAnalyticsnModelling opened 1 year ago

4SAnalyticsnModelling commented 1 year ago

I was getting `ERROR: TaskFailedException` while running the following code:

```julia
using Flux, Metal, GPUArraysCore

GPUArraysCore.allowscalar(false)   # forbid slow element-wise access on GPU arrays
Metal.functional()                 # true: a usable Metal device is present
Flux.gpu_backend!("Metal")         # select the Metal backend (takes effect after restart)
Flux.GPU_BACKEND                   # confirm which backend Flux is using

x_ = Metal.mtl(rand(Float32, 10, 10, 10, 3))   # WHCN input on the GPU
m1 = Conv((3, 3), 10 => 1)
m2 = Flux.gpu(m1)                  # move the layer to the GPU
m2(x_)                             # ERROR: TaskFailedException
```

I get the same error when I use a ConvTranspose layer instead of a Conv layer:

```julia
m1 = ConvTranspose((3, 3), 10 => 1)
m2 = Flux.gpu(m1)
m2(x_)                             # same TaskFailedException
```
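
A minimal stopgap, assuming the slow CPU path is acceptable for these layers, is to keep the conv on the CPU and move only its result to the GPU:

```julia
# CPU-fallback sketch (a workaround, not a fix): run the unsupported conv
# on the CPU, then move the result to the Metal device for later layers.
using Flux, Metal

x_cpu = rand(Float32, 10, 10, 10, 3)   # same shape as x_ above
m1 = Conv((3, 3), 10 => 1)             # weights stay on the CPU
y_gpu = Metal.mtl(m1(x_cpu))           # conv runs via NNlib's CPU im2col path
```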

Here's my `Metal.versioninfo()`:

```
macOS 13.4.0, Darwin 22.5.0

Toolchain:

Julia packages:

1 device:
```

And here's the error trace for the first code snippet:

```
ERROR: TaskFailedException

nested task error: TaskFailedException

    nested task error: Scalar indexing is disallowed.
    Invocation of getindex resulted in scalar indexing of a GPU array.
    This is typically caused by calling an iterating implementation of a method.
    Such implementations *do not* execute on the GPU, but very slowly on the CPU,
    and therefore are only permitted from the REPL for prototyping purposes.
    If you did intend to index this array, annotate the caller with @allowscalar.
    Stacktrace:
     [1] error(s::String)
       @ Base ./error.jl:35
     [2] assertscalar(op::String)
       @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
     [3] getindex(::MtlArray{Float32, 5}, ::Int64, ::Int64, ::Int64, ::Int64, ::Vararg{Int64})
       @ GPUArrays ~/.julia/packages/GPUArrays/t0LfC/src/host/indexing.jl:9
     [4] getindex
       @ ./subarray.jl:286 [inlined]
     [5] im2col!(col::SubArray{Float32, 2, MtlArray{Float32, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Int64}, true}, x::SubArray{Float32, 4, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Int64}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib ~/.julia/packages/NNlib/Fg3DQ/src/impl/conv_im2col.jl:228
     [6] macro expansion
       @ ~/.julia/packages/NNlib/Fg3DQ/src/impl/conv_im2col.jl:51 [inlined]
     [7] (::NNlib.var"#1110#threadsfor_fun#635"{NNlib.var"#1110#threadsfor_fun#634#636"{MtlArray{Float32, 3}, Float32, Float32, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}}})(tid::Int64; onethread::Bool)
       @ NNlib ./threadingconstructs.jl:163
     [8] #1110#threadsfor_fun
       @ ./threadingconstructs.jl:130 [inlined]
     [9] (::Base.Threads.var"#1#2"{NNlib.var"#1110#threadsfor_fun#635"{NNlib.var"#1110#threadsfor_fun#634#636"{MtlArray{Float32, 3}, Float32, Float32, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}}}, Int64})()
       @ Base.Threads ./threadingconstructs.jl:108

...and 2 more exceptions.

Stacktrace:
 [1] threading_run(fun::NNlib.var"#1110#threadsfor_fun#635"{NNlib.var"#1110#threadsfor_fun#634#636"{MtlArray{Float32, 3}, Float32, Float32, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}}}, static::Bool)
   @ Base.Threads ./threadingconstructs.jl:120
 [2] macro expansion
   @ ./threadingconstructs.jl:168 [inlined]
 [3] #conv_im2col!#633
   @ ~/.julia/packages/NNlib/Fg3DQ/src/impl/conv_im2col.jl:47 [inlined]
 [4] conv_im2col!(y::SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, x::SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3})
   @ NNlib ~/.julia/packages/NNlib/Fg3DQ/src/impl/conv_im2col.jl:23
 [5] (::NNlib.var"#305#309"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}, SubArray{Float32, 5, MtlArray{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
   @ NNlib ./threadingconstructs.jl:373

Stacktrace:
 [1] sync_end(c::Channel{Any})
   @ Base ./task.jl:445
 [2] macro expansion
   @ ./task.jl:477 [inlined]
 [3] conv!(out::MtlArray{Float32, 5}, in1::MtlArray{Float32, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ NNlib ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:205
 [4] conv!
   @ ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:185 [inlined]
 [5] conv!(y::MtlArray{Float32, 4}, x::MtlArray{Float32, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ NNlib ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:145
 [6] conv!
   @ ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:140 [inlined]
 [7] conv(x::MtlArray{Float32, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ NNlib ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:88
 [8] conv
   @ ~/.julia/packages/NNlib/Fg3DQ/src/conv.jl:83 [inlined]
 [9] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::MtlArray{Float32, 4})
   @ Flux ~/.julia/packages/Flux/n3cOc/src/layers/conv.jl:202
 [10] top-level scope
   @ ~/Documents/ecosys_June19_2023/ecosys_julia_test_run/src/test_run.jl:253
```
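
The inner error shows what actually fails: with no Metal kernels registered, NNlib falls back to its generic `im2col!` implementation, which reads the `MtlArray` one element at a time and trips the scalar-indexing guard enabled above. A minimal sketch reproducing just the guard, assuming any functional Metal device:

```julia
# Reproduces only the scalar-indexing guard, not the conv dispatch problem.
using Metal, GPUArraysCore

GPUArraysCore.allowscalar(false)   # same setting as in the report above
a = Metal.mtl(rand(Float32, 4))
a[1]                               # ERROR: Scalar indexing is disallowed.
```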

ToucheSir commented 1 year ago

This is more of an issue for Metal.jl. In short, someone needs to add support for Apple's Metal Performance Shaders CNN kernels (https://developer.apple.com/documentation/metalperformanceshaders/convolutional_neural_network_kernels) and plumb them through to NNlib.jl.
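
For context, that plumbing would follow the pattern other GPU backends already use (e.g. cuDNN-backed methods for `CuArray` in NNlib's CUDA extension): overload `NNlib.conv!` for `MtlArray` and forward to a native kernel. A rough sketch, where `mps_conv!` is a hypothetical wrapper over `MPSCNNConvolution` that does not exist in Metal.jl today:

```julia
# Hypothetical sketch only: `mps_conv!` stands in for a not-yet-written
# wrapper around Apple's MPSCNNConvolution kernel.
using NNlib, Metal

function NNlib.conv!(y::MtlArray{Float32,4}, x::MtlArray{Float32,4},
                     w::MtlArray{Float32,4}, cdims::DenseConvDims; kwargs...)
    mps_conv!(y, x, w, cdims)   # dispatch to Metal Performance Shaders
    return y
end
```

With such a method in place, `m2(x_)` above would dispatch to the Metal kernel instead of falling through to the scalar-indexing im2col path.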