FluxML / NNlib.jl

Neural Network primitives with multiple backends

CUDNN bad param for backward pass with Float16 #515

Open DrChainsaw opened 3 years ago

DrChainsaw commented 3 years ago

I guess half precision is not officially supported by Flux yet, but since it "almost" works already, perhaps it is worth looking into:

julia> cc = Flux.paramtype(Float16, Conv((3,3), 3 => 64)) |> gpu
Conv((3, 3), 3=>64)

julia> gradient(() -> sum(cc(ones(Float16, 3,3, 3, 1) |> gpu)), params(cc))
ERROR: CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
Stacktrace:
  [1] throw_api_error(res::CUDA.CUDNN.cudnnStatus_t)
    @ CUDA.CUDNN E:\Programs\julia\.julia\packages\CUDA\3VnCC\lib\cudnn\error.jl:22
  [2] macro expansion
    @ E:\Programs\julia\.julia\packages\CUDA\3VnCC\lib\cudnn\error.jl:39 [inlined]
  [3] cudnnConvolutionBackwardFilter(handle::Ptr{Nothing}, alpha::Base.RefValue{Float32}, xDesc::CUDA.CUDNN.cudnnTensorDescriptor, x::CUDA.CuArray{Float16, 4}, dyDesc::CUDA.CUDNN.cudnnTensorDescriptor, dy::CUDA.CuArray{Float16, 4}, convDesc::CUDA.CUDNN.cudnnConvolutionDescriptor, algo::CUDA.CUDNN.cudnnConvolutionBwdFilterAlgo_t, workSpace::CUDA.CuArray{UInt8, 1}, workSpaceSizeInBytes::Int64, beta::Base.RefValue{Float32}, dwDesc::CUDA.CUDNN.cudnnFilterDescriptor, dw::CUDA.CuArray{Float16, 4})
    @ CUDA.CUDNN E:\Programs\julia\.julia\packages\CUDA\3VnCC\lib\utils\call.jl:26
  [4] macro expansion
    @ E:\Programs\julia\.julia\dev\NNlibCUDA\src\cudnn\conv.jl:91 [inlined]
  [5] macro expansion
    @ E:\Programs\julia\.julia\packages\CUDA\3VnCC\lib\utils\call.jl:144 [inlined]
  [6] ∇conv_filter!(dw::CUDA.CuArray{Float16, 4}, x::CUDA.CuArray{Float16, 4}, dy::CUDA.CuArray{Float16, 4}, cdims::DenseConvDims{2, (3, 3), 3, 64, (1, 1), (0, 0, 0, 0), (1, 1), false}; alpha::Int64, beta::Int64, algo::Int64)
    @ NNlibCUDA E:\Programs\julia\.julia\dev\NNlibCUDA\src\cudnn\conv.jl:91
  [7] ∇conv_filter!
    @ E:\Programs\julia\.julia\dev\NNlibCUDA\src\cudnn\conv.jl:81 [inlined]
  [8] #∇conv_filter#89
    @ E:\Programs\julia\.julia\packages\NNlib\ev8gq\src\conv.jl:116 [inlined]
  [9] ∇conv_filter
    @ E:\Programs\julia\.julia\packages\NNlib\ev8gq\src\conv.jl:114 [inlined]
 [10] #182
    @ E:\Programs\julia\.julia\packages\NNlib\ev8gq\src\conv.jl:229 [inlined]
 [11] Thunk
    @ E:\Programs\julia\.julia\packages\ChainRulesCore\fsJxJ\src\differentials\thunks.jl:98 [inlined]
 [12] unthunk
    @ E:\Programs\julia\.julia\packages\ChainRulesCore\fsJxJ\src\differentials\thunks.jl:99 [inlined]
 [13] wrap_chainrules_output
    @ E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\chainrules.jl:41 [inlined]
 [14] map
    @ .\tuple.jl:215 [inlined]
 [15] map
    @ .\tuple.jl:216 [inlined]
 [16] wrap_chainrules_output
    @ E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\chainrules.jl:42 [inlined]
 [17] ZBack
    @ E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\chainrules.jl:77 [inlined]
 [18] Pullback
    @ E:\Programs\julia\.julia\packages\Flux\6o4DQ\src\layers\conv.jl:157 [inlined]
 [19] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float16, 4})
    @ Zygote E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\interface2.jl:0
 [20] Pullback
    @ .\REPL[16]:1 [inlined]
 [21] (::typeof(∂(#11)))(Δ::Float16)
    @ Zygote E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(#11)), Zygote.Context})(Δ::Float16)
    @ Zygote E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\interface.jl:255
 [23] gradient(f::Function, args::Zygote.Params)
    @ Zygote E:\Programs\julia\.julia\packages\Zygote\i1R8y\src\compiler\interface.jl:59
 [24] top-level scope
    @ REPL[16]:1

# With fewer output channels it works; a workspace size issue?
julia> cc = Flux.paramtype(Float16, Conv((3,3), 3 => 3)) |> gpu
Conv((3, 3), 3=>3)

julia> gradient(() -> sum(cc(ones(Float16, 3,3, 3, 1) |> gpu)), params(cc))
Grads(...)

julia> typeof.(values(ans))
2-element Vector{DataType}:
 CUDA.CuArray{Float16, 4}
 CUDA.CuArray{Float16, 1}
DrChainsaw commented 3 years ago

@denizyuret, @maleadt sorry for poking.

Would it be easy for you to spot whether the function below is correct usage of the cuDNN API? If so, perhaps this issue belongs in CUDA.jl instead?

https://github.com/FluxML/NNlibCUDA.jl/blob/b8baa3ddcdb7e8ca3c89153e5045209dc7bab7e0/src/cudnn/conv.jl#L86-L100
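For reference, the failure can be reproduced without Flux or Zygote by calling NNlib's ∇conv_filter directly. A minimal sketch, assuming NNlibCUDA is loaded so the cuDNN path is taken:

julia> using NNlib, NNlibCUDA, CUDA

julia> x = CUDA.ones(Float16, 3, 3, 3, 1);   # input with 3 channels

julia> w = CUDA.ones(Float16, 3, 3, 3, 64);  # filter for a 3 => 64 Conv

julia> cdims = DenseConvDims(x, w);

julia> y = NNlib.conv(x, w, cdims);          # forward pass works fine

julia> dy = CUDA.ones(Float16, size(y));     # dummy output sensitivity

julia> NNlib.∇conv_filter(x, dy, cdims)      # expected to hit the same CUDNN_STATUS_BAD_PARAM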

maleadt commented 3 years ago

Try running with JULIA_DEBUG=CUDNN (on the latest CUDA.jl) and comparing the params to the error causes listed in https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionBackwardFilter.
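For reference, one way to enable that logging is to set the environment variable before Julia starts (a sketch; the exact setup may vary with the CUDA.jl version):

$ JULIA_DEBUG=CUDNN julia

julia> using Flux, CUDA  # rerun the failing gradient; each cuDNN API call is now logged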

DrChainsaw commented 3 years ago

Thanks. That's some great debug output!

For some reason it did not print everything the first time; I had to rerun a couple of times before the relevant function appeared.

┌ Debug: CuDNN (v8200) function cudnnConvolutionBackwardFilter() called:
│     handle: type=cudnnHandle_t; streamId=00000000B8F8EDC0;
│     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
│     xDesc: type=cudnnTensorDescriptor_t:
│         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
│         nbDims: type=int; val=4;
│         dimA: type=int; val=[1,3,3,3];
│         strideA: type=int; val=[27,9,3,1];
│     xData: location=dev; addr=0000000203C01000;
│     dyDesc: type=cudnnTensorDescriptor_t:
│         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
│         nbDims: type=int; val=4;
│         dimA: type=int; val=[1,64,1,1];
│         strideA: type=int; val=[64,1,1,1];
│     dyData: location=dev; addr=0000000203C01600;
│     convDesc: type=cudnnConvolutionDescriptor_t:
│         mode: type=cudnnConvolutionMode_t; val=CUDNN_CONVOLUTION (0);
│         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
│         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
│         reorderType: type=int; val=0;
│         arrayLength: type=int; val=2;
│         padA: type=int; val=[0,0];
│         strideA: type=int; val=[1,1];
│         dilationA: type=int; val=[1,1];
│         groupCount: type=int; val=1;
│     algo: type=cudnnConvolutionBwdFilterAlgo_t; val=CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 (1);
│     workSpace: location=dev; addr=0000000203C03000;
│     workSpaceSizeInBytes: type=unsigned long long; val=5400;
│     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
│     dwDesc: type=cudnnFilterDescriptor_t:
│         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
│         vect: type=int; val=0;
│         nbDims: type=int; val=4;
│         dimA: type=int; val=[64,3,3,3];
│         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
│     dwData: location=dev; addr=0000000203C01A00;
│ Time: 2021-06-14T22:20:18.800688 (0d+0h+1m+4s since start)
│ Process=15972; Thread=14580; GPU=0; Handle=00000000E8189540; StreamId=00000000B8F8EDC0.
└ @ CUDA.CUDNN E:\Programs\julia\.julia\packages\CUDA\mVgLI\lib\cudnn\CUDNN.jl:123
ERROR: CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)
Stacktrace:

I could not spot anything violating the conditions listed for BAD_PARAM. Looking at the table of supported algos, the data types for xDesc, dyDesc, convDesc and dwDesc correspond to TRUE_HALF_CONFIG, which is not listed as supported by CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1. But that should have yielded CUDNN_STATUS_NOT_SUPPORTED rather than BAD_PARAM, right?
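For reference, the data type configurations as the cuDNN docs define them (paraphrased; worth double-checking against the linked API reference):

TRUE_HALF_CONFIG:   xDesc, dyDesc, convDesc and dwDesc all CUDNN_DATA_HALF
PSEUDO_HALF_CONFIG: xDesc, dyDesc and dwDesc CUDNN_DATA_HALF; convDesc CUDNN_DATA_FLOAT
FLOAT_CONFIG:       all four CUDNN_DATA_FLOAT
DOUBLE_CONFIG:      all four CUDNN_DATA_DOUBLE

The debug output above shows convDesc with dataType CUDNN_DATA_HALF, i.e. TRUE_HALF_CONFIG.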

An attempt to be slightly more than useless, listing each condition against what I think is the relevant part of the output:

DrChainsaw commented 3 years ago

Did a bit of hacking around, and it seems like changing the algo to _WINOGRAD_NONFUSED (which seems to be the only one with support for TRUE_HALF_CONFIG) resulted in CUDNN_STATUS_NOT_SUPPORTED. Not sure why, since all the fine print in the last column of the table seems to be fulfilled.

Edit: never mind, I missed that the order of the returned algos is not deterministic; _WINOGRAD_NONFUSED does work. Btw, it reports CUDA.CUDNN.CUDNN_STATUS_ALLOC_FAILED for the default workspace size but seems to succeed anyway (the data looks the same as with ALGO_1 and PSEUDO_HALF_CONFIG). This might explain why the small filter size works but the large one does not. Could increasing the workspace size in cudnnFindConvolutionAlgorithmWorkspaceSize fix this, or perhaps accepting CUDNN_STATUS_ALLOC_FAILED in cudnnConvolutionAlgoPerfChoose (sounds risky)? A sketch of the latter idea follows below.
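A purely illustrative sketch of that second idea; the actual filtering happens inside CUDA.jl's cuDNN wrapper, and perfs here stands in for the vector of cudnnConvolutionBwdFilterAlgoPerf_t results it inspects:

using CUDA.CUDNN: CUDNN_STATUS_SUCCESS, CUDNN_STATUS_ALLOC_FAILED

# Accept ALLOC_FAILED results too (the risky part): cuDNN may still manage
# with a smaller workspace at execution time, but nothing guarantees it.
choose_perf(perfs) = first(p for p in perfs
                           if p.status in (CUDNN_STATUS_SUCCESS, CUDNN_STATUS_ALLOC_FAILED))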

Changing the convDesc data type to Float32, so that the data type configuration becomes PSEUDO_HALF_CONFIG, also works.

Is this the correct fix for cases where _WINOGRAD_NONFUSED is not applicable? Support for TRUE_HALF_CONFIG seems quite limited.
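For anyone hitting this in the meantime, a user-level workaround in the same spirit (a minimal sketch, not the eventual fix: it emulates the PSEUDO_HALF numerics by computing in Float32 and casting back, whereas a proper fix would set the convolution descriptor's dataType to Float32 while keeping the tensors in Float16):

using NNlib, CUDA

# Keep storage in Float16 but compute the filter gradient in Float32
# (FLOAT_CONFIG, which every backward-filter algo supports), then cast back.
grad_filter_f32(x::CuArray{Float16}, dy::CuArray{Float16}, cdims) =
    Float16.(NNlib.∇conv_filter(Float32.(x), Float32.(dy), cdims))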

ToucheSir commented 1 year ago

This and https://github.com/FluxML/NNlib.jl/issues/505 are dupes. Depending on which one sees more activity, I'll close the other one.