Open pxl-th opened 3 years ago
Hi, I've been using a UNet-like architecture that accepts different encoders. When passing EfficientNet as the encoder (it contains BatchNorm in its MBConv blocks), it crashes during gradient computation, but only on the GPU.
Not 100% sure the issue is with this library, but here's the stacktrace:
Can you reduce this down to a MWE? A MethodError shouldn't be too difficult to troubleshoot; we just need to trace the data coming into batchnorm.
Yeah, I should've done that from the beginning :) I figured that since both the encoder blocks and the decoder blocks in my code end with BatchNorm, I might as well construct everything from them.
MWE:
using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function decode(decoder, features)
    features = features[end:-1:1]
    head, skips = features[1], features[2:end]
    x = head
    for (i, block) in enumerate(decoder)
        if i ≤ length(skips)
            x = cat(x, skips[i]; dims=3)
        end
        x = block(x)
    end
    x
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    encoder = Chain(BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)
    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()
Produces the same error:
ERROR: LoadError: MethodError: no method matching
∇batchnorm(::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=1.0f-5, training=true)
Closest candidates are:
∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
[1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
[2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
[3] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
[4] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[5] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
[6] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[7] Pullback
@ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
[8] (::typeof(∂(encode)))(Δ::Vector{Union{Nothing, CUDA.CuArray{Float32, 4}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[9] Pullback
@ ~/projects/Segmentation.jl/src/mwe.jl:39 [inlined]
[10] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[11] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
[12] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
[13] main()
@ Main ~/projects/Segmentation.jl/src/mwe.jl:38
[14] top-level scope
@ ~/projects/Segmentation.jl/src/mwe.jl:45
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:45
Likely due to FillArrays.Ones, which may come from the push! adjoint.
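For reference, here is a tiny CPU-only illustration (not from the thread) of where such a lazy array can come from: the cotangent for sum is a FillArrays object rather than a dense Array, and that is the Δ that eventually reaches the batchnorm pullback.

using Zygote, FillArrays

x = rand(Float32, 3)
y, back = Zygote.pullback(sum, x)
dx, = back(one(y))
# On the package versions in this thread, `dx` comes back as a lazy
# FillArrays value (a Fill/Ones), not a dense Array: the same kind of
# object seen as Δ in the stacktrace above.
dx isa FillArrays.AbstractFill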
I would get rid of the loops used to construct and run the models, etc., and see how that performs.
As an aside, I'd love an EfficientNet implementation in Metalhead as a PR ;)
Indeed, it comes from the push! adjoint. However, if you replace the identity activation with any other activation function (e.g. relu), the error disappears, presumably because the non-identity activation's broadcast pullback materializes the lazy Fill into a dense CuArray before it reaches ∇batchnorm, whereas identity passes it through unchanged. But in MBConv the last BatchNorm has no activation, so that is not an option there.
Here's an even smaller MWE:
using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    encoder = Chain(BatchNorm(3, identity), BatchNorm(3, identity)) |> device |> trainmode!
    θ = params(encoder)
    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()
Also, I'm not sure how you would get rid of the loops without unrolling them manually and losing generality. I want to be able to pass different encoders, which can have different feature-extraction depths. Having separate encoding and decoding stages makes things easier, similar to how it is done in the segmentation_models Python package.
Maybe, for the GPU, we should "materialize" FillArrays.Ones into a CuArray if we get one? Especially since this already works fine on the CPU, it would make sense to have the same support on the GPU.
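A minimal sketch of what that materialization could look like, assuming the conversion happens right where the cuDNN wrapper is called. ∇batchnorm is the NNlibCUDA function from the stacktrace (assuming NNlibCUDA is importable in the environment, as it is a Flux dependency); the forwarding method below is purely illustrative and not an actual patch from any of these packages:

using CUDA, FillArrays
import NNlibCUDA: ∇batchnorm

# Hypothetical forwarding method: densify a lazy Fill/Ones cotangent before
# handing it to the CuArray-only cuDNN wrapper.
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
                    dy::FillArrays.AbstractFill, running_mean, running_var,
                    momentum; kwargs...) where T
    dy_dense = CuArray(collect(dy))
    ∇batchnorm(g, b, x, dy_dense, running_mean, running_var, momentum; kwargs...)
end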
A similar thing happens if you replace BatchNorm with Conv, and the error again disappears if you specify a non-identity activation function. With Conv the failure shows up as a gemm! MethodError mixing Ptr and CuPtr, presumably because the lazy Ones cotangent is not a CuArray, so NNlib's conv pullback falls back to the CPU im2col path while the weights are still on the GPU.
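Just to make the dispatch point concrete (illustrative snippet, not from the original code): the lazy cotangent is an AbstractArray but not a CuArray, so it slips past the GPU-specialised methods and into the generic CPU fallbacks.

using CUDA, FillArrays

Δ = FillArrays.Ones{Float32}(10, 10, 3, 1)
Δ isa CUDA.CuArray   # false: the CuArray-specialised GPU methods never see it
Δ isa AbstractArray  # true: the generic (CPU) fallbacks accept it, mixing devices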
MWE:
using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    encoder = Chain(
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
    ) |> device |> trainmode!
    θ = params(encoder)
    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()
Error:
ERROR: LoadError: TaskFailedException
nested task error: MethodError: no method matching
gemm!(::Val{false}, ::Val{true}, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::CUDA.CuPtr{Float32}, ::Float32, ::Ptr{Float32})
Closest candidates are:
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
...
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:156 [inlined]
[2] (::NNlib.var"#752#threadsfor_fun#391"{Float32, Array{Float32, 3}, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CUDA.CuArray{Float32, 5}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, Int64, Int64, Int64, UnitRange{Int64}})(onethread::Bool)
@ NNlib ./threadingconstructs.jl:81
[3] #invokelatest#2
@ ./essentials.jl:708 [inlined]
[4] invokelatest
@ ./essentials.jl:706 [inlined]
[5] macro expansion
@ ./threadingconstructs.jl:86 [inlined]
[6] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; col::Array{Float32, 3}, alpha::Float32, beta::Float32)
@ NNlib ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:148
[7] ∇conv_data_im2col!
@ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:127 [inlined]
[8] (::NNlib.var"#162#166"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, CUDA.CuArray{Float32, 5}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
@ NNlib ./threadingconstructs.jl:169
Stacktrace:
[1] sync_end(c::Channel{Any})
@ Base ./task.jl:369
[2] macro expansion
@ ./task.jl:388 [inlined]
[3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:228
[4] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false})
@ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:217
[5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::CUDA.CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 3, 3, 1, (1, 1), (1, 1, 1, 1), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151
[6] ∇conv_data!
@ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151 [inlined]
[7] #∇conv_data#89
@ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:104 [inlined]
[8] ∇conv_data
@ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:101 [inlined]
[9] #204
@ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:313 [inlined]
[10] unthunk
@ ~/.julia/packages/ChainRulesCore/BYuIz/src/differentials/thunks.jl:192 [inlined]
[11] wrap_chainrules_output
@ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:55 [inlined]
[12] map
@ ./tuple.jl:215 [inlined]
[13] map
@ ./tuple.jl:216 [inlined]
[14] wrap_chainrules_output
@ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:56 [inlined]
[15] ZBack
@ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:91 [inlined]
[16] Pullback
@ ~/.julia/packages/Flux/Zz9RI/src/layers/conv.jl:165 [inlined]
[17] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[18] Pullback
@ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
[19] (::typeof(∂(encode)))(Δ::Vector{CUDA.CuArray{Float32, 4}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[20] Pullback
@ ~/projects/Segmentation.jl/src/mwe.jl:21 [inlined]
[21] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[22] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
[23] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
[24] main()
@ Main ~/projects/Segmentation.jl/src/mwe.jl:20
[25] top-level scope
@ ~/projects/Segmentation.jl/src/mwe.jl:26
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:26
This doesn't resolve the underlying issue, but you shouldn't have to use any mutation or explicit loops to write an EfficientNet-style architecture. Here's an equivalent version that works:
using Flux

encode(encoder, x) = map(f -> f(x), encoder)

function decode(decoder, features)
    nskips = length(features) - 1
    skip_blocks, rest_blocks = decoder[1:nskips], decoder[nskips+1:end]
    xs = foldl(zip(skip_blocks, features[nskips:-1:1]); init=features[end]) do acc, (f, x)
        f(cat(acc, x; dims=ndims(x) - 1)) # ndims(x) - 1 == 3 here, but is more general
    end
    return rest_blocks(xs)
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    # encoders are not chained (they are run in parallel), so don't make them a Chain
    encoder = (BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)
    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()
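As a quick CPU-only sanity check (an illustrative snippet, not part of the original comment; sizes are arbitrary), the mutation-free encode differentiates with no push! in sight:

using Flux

enc = (BatchNorm(3), BatchNorm(3)) |> trainmode!   # arbitrary small encoder blocks
x = rand(Float32, 4, 4, 3, 1)
g = gradient(params(enc)) do
    feats = map(f -> f(x), enc)        # tuple of feature maps, no mutation
    sum(reduce(+, feats))              # same kind of loss as the MWEs above
end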
Yes, I've changed the feature extraction part (the encoder) of the model to use map instead of loops, and I can now take gradients. Although having support for push! in this case would have been nice as well.
Here's a fun animation of the training dynamics on a selected image from a small dataset, if anyone is curious :)