denizyuret / Knet.jl

Koç University deep learning framework.
https://denizyuret.github.io/Knet.jl/latest

"No good algo found" error when JULIA_NUM_THREADS > 1 #575

Open jonathan-laurent opened 4 years ago

jonathan-laurent commented 4 years ago

I am encountering a "No good algo found" error when running AlphaZero.jl on some machines with JULIA_NUM_THREADS > 1:

No good algo found.
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] perfChoose(::Array{Knet.cudnnConvolutionFwdAlgoPerf_t,1}, ::Int32) at /home/jonathan-laurent/.julia/packages/Knet/bTNMd/src/conv.jl:589
 [3] conv4_algo(::Knet.KnetArray{Float32,4}, ::Knet.KnetArray{Float32,4}, ::Knet.KnetArray{Float32,4}; handle::Ptr{Nothing}, o::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan-laurent/.julia/packages/Knet/bTNMd/src/conv.jl:524
 [4] conv4(::Knet.KnetArray{Float32,4}, ::Knet.KnetArray{Float32,4}; handle::Ptr{Nothing}, alpha::Int64, o::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan-laurent/.julia/packages/Knet/bTNMd/src/conv.jl:40
 [5] forw(::Function, ::AutoGrad.Param{Knet.KnetArray{Float32,4}}, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan-laurent/.julia/packages/AutoGrad/6QsMu/src/core.jl:66
 [6] #conv4#276 at ./none:0 [inlined]
 [7] (::AlphaZero.KNets.Conv)(::Knet.KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet/layers.jl:60
 [8] (::AlphaZero.KNets.Chain)(::Knet.KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet/layers.jl:19
 [9] forward(::ResNet{Game}, ::Knet.KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet.jl:149

Note that:

Here are the specs of one machine where the bug occurs:

julia> include(Knet.dir("test","gpu.jl"))
Knet.gpuCount() = 9
Knet.gpu() = 7
Knet.tk = ["/usr/local/cuda"]
Knet.libknet8 = "/home/jonathan-laurent/.julia/packages/Knet/bTNMd/deps/libknet8"
Knet.cudartfound = true
Knet.cudaRuntimeVersion = 10010
Knet.cudaDriverVersion = 10010
Knet.cudaGetDeviceCount() = 9
Knet.cudaGetDevice() = 7
Knet.cudaMemGetInfo() = (33719648256, 34058272768)
Knet.cudaDeviceSynchronize() = nothing
Knet.nvmlfound = true
Knet.nvmlDriverVersion = "418.67"
Knet.nvmlVersion = "10.418.67"
Knet.nvmlDeviceGetMemoryInfo() = (34058272768, 33719648256, 338624512)
Knet.cublashandle() = Ptr{Nothing} @0x000000000c262350
Knet.cublasVersion = 10200
Knet.cudnnhandle() = Ptr{Nothing} @0x000000000200ff40
Knet.cudnnVersion = 7605
Knet.dir() = "/home/jonathan-laurent/.julia/packages/Knet/bTNMd"
readdir(Knet.dir("deps")) = [".deprecated", ".gitignore", "Makefile", "README.windows", "build.jl", "build.log", "cuda01.cu", "cuda01.jl", "cuda01.o", "cuda1.cu", "cuda1.jl", "cuda1.o", "cuda11.cu", "cuda11.jl", "cuda11.o", "cuda12.cu", "cuda12.jl", "cuda12.o", "cuda13.cu", "cuda13.jl", "cuda13.o", "cuda14.jl", "cuda16.cu", "cuda16.jl", "cuda16.o", "cuda17.cu", "cuda17.jl", "cuda17.o", "cuda20.cu", "cuda20.jl", "cuda20.o", "cuda21.cu", "cuda21.jl", "cuda21.o", "cuda22.cu", "cuda22.jl", "cuda22.o", "gamma.jl", "libknet8.so"]

Here, GPU 7 is a Tesla V100. I tried configuring Knet to use other GPUs, but the error persists. I am using Knet@1.3.5. Finally, note that although my program uses Threads.@spawn, only a single task makes calls to the GPU through Knet.

To reproduce the problem:

git clone https://github.com/jonathan-laurent/AlphaZero.jl
cd AlphaZero.jl
git checkout b9b20f377d8d456ae280239b030e291bbe5c968f
export JULIA_NUM_THREADS=4
julia --project -e "import Pkg; Pkg.instantiate()"
julia --project --color=yes scripts/profile/distributed_self_play.jl
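The setup described above (many threads, but only one task calling into the GPU) can be illustrated with a minimal sketch in which a single "GPU owner" task serves requests from other tasks over a Channel. This is an assumed design for illustration, not necessarily how AlphaZero.jl is structured; `gpu_forward`, `Request`, `start_gpu_server` and `evaluate` are hypothetical names, and `gpu_forward` is a CPU stand-in for the real Knet forward pass.

```julia
# Hypothetical stand-in for the GPU forward pass; in the real program this
# would run the Knet network on a KnetArray.
gpu_forward(x::Vector{Float32}) = 2f0 .* x

# A request carries the input and a channel on which the owner task replies.
struct Request
    input::Vector{Float32}
    reply::Channel{Vector{Float32}}
end

# Start the single task that is allowed to call into the GPU library.
function start_gpu_server(requests::Channel{Request})
    return Threads.@spawn begin
        for req in requests          # loop ends once `requests` is closed and drained
            put!(req.reply, gpu_forward(req.input))
        end
    end
end

# Other tasks call this instead of touching the GPU directly.
function evaluate(requests::Channel{Request}, x::Vector{Float32})
    reply = Channel{Vector{Float32}}(1)
    put!(requests, Request(x, reply))
    return take!(reply)
end

# Usage: several spawned tasks share the single GPU-owning task.
requests = Channel{Request}(32)
server = start_gpu_server(requests)
results = fetch.([Threads.@spawn evaluate(requests, Float32[i]) for i in 1:4])
close(requests)   # lets the server loop finish
wait(server)
```

Funneling every call through one task means library state such as a cuDNN handle is only ever touched from that task, at the cost of serializing GPU work.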
denizyuret commented 4 years ago

Can you add the following to your code and let me know if it solves the problem:

# Skip Knet's cuDNN algorithm search and always use algorithm 0 with the default workspace,
# for the forward pass (conv4_algo) and the filter/data gradients (conv4w_algo, conv4x_algo).
Knet.conv4_algo(w::KnetArray{T}, x::KnetArray{T}, y::KnetArray{T}; handle=Knet.cudnnhandle(), o...) where {T} = (0, Knet.cudnnWorkSpace())
Knet.conv4w_algo(w::KnetArray{T}, x::KnetArray{T}, dy::KnetArray{T}, dw::KnetArray{T}; handle=Knet.cudnnhandle(), o...) where {T} = (0, Knet.cudnnWorkSpace())
Knet.conv4x_algo(w::KnetArray{T}, x::KnetArray{T}, dy::KnetArray{T}, dx::KnetArray{T}; handle=Knet.cudnnhandle(), o...) where {T} = (0, Knet.cudnnWorkSpace())
jonathan-laurent commented 4 years ago

Now I am getting this:

MethodError: no method matching bytes(::Int64)
Closest candidates are:
  bytes(::KnetArray{T,N} where N) where T at /home/jonathan/.julia/packages/Knet/bTNMd/src/conv.jl:505
Stacktrace:
 [1] conv4(::KnetArray{Float32,4}, ::KnetArray{Float32,4}; handle::Ptr{Nothing}, alpha::Int64, o::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan/.julia/packages/Knet/bTNMd/src/conv.jl:41
 [2] forw(::Function, ::Param{KnetArray{Float32,4}}, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan/.julia/packages/AutoGrad/6QsMu/src/core.jl:66
 [3] #conv4#276 at ./none:0 [inlined]
 [4] (::AlphaZero.KNets.Conv)(::KnetArray{Float32,4}) at /home/jonathan/AlphaZero.jl/src/networks/knet/layers.jl:60
 [5] (::AlphaZero.KNets.Chain)(::KnetArray{Float32,4}) at /home/jonathan/AlphaZero.jl/src/networks/knet/layers.jl:19
 [6] forward(::ResNet{Game}, ::KnetArray{Float32,4}) at /home/jonathan/AlphaZero.jl/src/networks/knet.jl:149
denizyuret commented 4 years ago

Sorry about that; I edited the post with a small fix. Please try again.

jonathan-laurent commented 4 years ago

Ok, so now I am getting:

cudnnConvolutionForward: 7: CUDNN_STATUS_MAPPING_ERROR
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] macro expansion at /home/jonathan-laurent/.julia/packages/Knet/bTNMd/src/gpu.jl:34 [inlined]
 [3] conv4(::KnetArray{Float32,4}, ::KnetArray{Float32,4}; handle::Ptr{Nothing}, alpha::Int64, o::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan-laurent/.julia/packages/Knet/bTNMd/src/conv.jl:41
 [4] forw(::Function, ::Param{KnetArray{Float32,4}}, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:padding,),Tuple{Tuple{Int64,Int64}}}}) at /home/jonathan-laurent/.julia/packages/AutoGrad/6QsMu/src/core.jl:66
 [5] #conv4#276 at ./none:0 [inlined]
 [6] (::AlphaZero.KNets.Conv)(::KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet/layers.jl:60
 [7] (::AlphaZero.KNets.Chain)(::KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet/layers.jl:19
 [8] forward(::ResNet{Game}, ::KnetArray{Float32,4}) at /home/jonathan-laurent/AlphaZero.jl/src/networks/knet.jl:149
denizyuret commented 4 years ago

The cuDNN documentation says:

CUDNN_STATUS_MAPPING_ERROR
An access to GPU memory space failed, which is usually caused by a failure to bind a texture.
To correct: prior to the function call, unbind any previously bound textures.
Otherwise, this may indicate an internal error/bug in the library.

At this point we may want to bring in @maleadt:

The CUDA.jl documentation on multiple GPUs (https://juliagpu.gitlab.io/CUDA.jl/usage/multigpu/) says the following about using multiple GPUs from a single process:

"The CUDA memory pool is not device-aware yet, effectively breaking multi-gpu-single-process concurrency. Don't use this approach for serious work unless you can support with cross-device memory operations (e.g. with cuCtxEnablePeerAccess)."

Here is what it says for multiple threads:

"This approach is not recommended, as multi-threading is a fairly recent addition to the language and many packages, including those for Julia GPU programming, have not been made thread-safe yet. For now, the toolchain mimics the behavior of the CUDA runtime library and uses a single context across all devices."

The only stable option at the moment seems to be one process per GPU. Is this possible in your case?

P.S. I just realized that you may not be using multiple GPUs, only multiple threads. @maleadt may be able to tell us whether this can cause issues with the CUDA stack even when only a single GPU is used.
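The one-process-per-GPU option suggested above can be sketched with Julia's Distributed standard library: start one worker per GPU and have each worker select its own device before doing any GPU work. This is a minimal sketch under assumptions: `NGPUS`, `MY_DEVICE` and `set_device!` are hypothetical names, and the actual device selection is left as a placeholder (in Knet this would presumably be something like `Knet.gpu(dev)`, which is not confirmed here) so the sketch runs without a GPU.

```julia
using Distributed

# Hypothetical GPU count; on a real machine this might come from Knet.gpuCount().
const NGPUS = 2

# One worker process per GPU.
addprocs(NGPUS)

# Define the per-process state on every process (master included, so that
# `set_device!` can be referenced by name in remote calls).
@everywhere begin
    const MY_DEVICE = Ref(-1)
    function set_device!(dev::Int)
        # Placeholder: real code would select the GPU here, before any other
        # GPU call in this process.
        MY_DEVICE[] = dev
        return dev
    end
end

# Give the i-th worker device i-1 (CUDA numbers devices from 0).
for (i, w) in enumerate(workers())
    remotecall_fetch(set_device!, w, i - 1)
end
```

Because each process has its own CUDA state, nothing GPU-related is shared between threads; the price is that data must be moved between processes explicitly.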

jonathan-laurent commented 4 years ago

Indeed, I am not using multiple GPUs. In fact, in the program that exhibits the bug above, all calls to the GPU happen within a single thread.

Note that this bug is not currently blocking me, so please don't feel pressured to spend too much time on it.

Updated instructions to replicate (note that the bug does not happen on all machines):

git clone https://github.com/jonathan-laurent/AlphaZero.jl
cd AlphaZero.jl
git checkout distributed-knet-fix
export JULIA_NUM_THREADS=4
julia --project -e "import Pkg; Pkg.instantiate()"
julia --project --color=yes scripts/profile/distributed_self_play.jl
maleadt commented 4 years ago

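The problem might be that Knet uses a single global handle with libraries like cuDNN, which is invalid according to the cuDNN thread-safety documentation (https://docs.nvidia.com/deeplearning/sdk/cudnn-archived/cudnn_701/cudnn-user-guide/index.html#thread-safety):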

The library is thread safe and its functions can be called from multiple host threads, as long as threads to do not share the same cuDNN handle simultaneously.

CUDA.jl takes care to create separate handles for each task/thread, e.g. https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/CUDNN.jl#L36-L69
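The per-task idea can be shown with a minimal sketch that caches one handle per task in `task_local_storage()`, so two tasks never share a handle. This is not CUDA.jl's actual code (the linked implementation is more involved); `FakeHandle`, `NEXT_ID`, `create_handle` and `task_handle` are hypothetical names, and `create_handle` only stands in for a real handle constructor such as the one cuDNN provides.

```julia
# Hypothetical stand-in for a library handle (e.g. a cuDNN handle).
mutable struct FakeHandle
    id::Int
end

# Counter used only to give each fake handle a distinct id.
const NEXT_ID = Threads.Atomic{Int}(0)

# Stand-in for creating a real library handle; assumed expensive, so it is cached.
create_handle() = FakeHandle(Threads.atomic_add!(NEXT_ID, 1) + 1)

# Return the handle owned by the current task, creating it on first use.
# task_local_storage() is a per-task dictionary that child tasks do not inherit.
task_handle() = get!(create_handle, task_local_storage(), :fake_cudnn_handle)

# Usage: repeated calls in one task reuse its handle; spawned tasks get their own.
h_main = task_handle()
h_other = fetch(Threads.@spawn task_handle())
```

A real implementation would also need to release handles when tasks finish and to key them by device or context; the sketch omits both.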

denizyuret commented 4 years ago

Excellent. This means that once I manage to fully integrate CUDA, this might work.
