FluxML / NNlib.jl

Neural Network primitives with multiple backends

BatchNorm causes error during gradient computation #514

Open pxl-th opened 3 years ago

pxl-th commented 3 years ago

Hi, I've been using a UNet-like architecture that accepts different encoders. When passing EfficientNet as the encoder (which contains BatchNorm in its MBConv blocks), it crashes during gradient computation, but only on the GPU.

Not 100% sure the issue is with this library, but here's the stacktrace:

ERROR: LoadError: MethodError: no method matching
  ∇batchnorm(::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CuArray{Float32, 1}, ::CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=0.001f0, training=true)
Closest candidates are:
  ∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
  ∇batchnorm(::CuArray{T, N} where N, ::CuArray{T, N} where N, ::CuArray{T, 2}, ::CuArray{T, 2}, ::CuArray{T, N} where N, ::CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
  [1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
  [2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CuArray{Float32, 1}, CuArray{Float32, 1}, CuArray{Float32, 4}, CuArray{Float32, 1}, CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
  [3] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
  [4] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
  [6] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
  [7] (::typeof(∂(applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
  [9] (::typeof(∂(applychain)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
 [11] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./operators.jl:858 [inlined]
 [13] (::typeof(∂(|>)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:109 [inlined]
 [15] (::typeof(∂(#_#7)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/mb.jl:100 [inlined]
 [17] (::typeof(∂(Any##kw)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/EfficientNet/NKvyu/src/model.jl:125 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Vector{Union{Nothing, CuArray{Float32, 4}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/projects/Segmentation.jl/src/Segmentation.jl:27 [inlined]
 [21] (::typeof(∂(λ)))(Δ::CuArray{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./operators.jl:858 [inlined]
 [23] (::typeof(∂(|>)))(Δ::CuArray{Float32, 4})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/projects/Segmentation.jl/example/comma.jl:185 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [27] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [28] test_grads()
    @ Main ~/projects/Segmentation.jl/example/comma.jl:184
 [29] top-level scope
    @ ~/projects/Segmentation.jl/example/comma.jl:199
in expression starting at /home/pxl-th/projects/Segmentation.jl/example/comma.jl:199
ToucheSir commented 3 years ago

Can you reduce this down to a MWE? A MethodError shouldn't be too difficult to troubleshoot; we just need to trace the data coming into batchnorm.

pxl-th commented 3 years ago

Yeah, I should've done that from the beginning :) Since both the encoder and decoder blocks in my code end with BatchNorm, I figured I might as well construct everything out of BatchNorm layers.

MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function decode(decoder, features)
    features = features[end:-1:1]
    head, skips = features[1], features[2:end]

    x = head
    for (i, block) in enumerate(decoder)
        if i ≤ length(skips)
            x = cat(x, skips[i]; dims=3)
        end
        x = block(x)
    end
    x
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    encoder = Chain(BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)

    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()

Produces the same error:

ERROR: LoadError: MethodError: no method matching
  ∇batchnorm(::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=1.0f-5, training=true)
Closest candidates are:
  ∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
  ∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
  [1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
  [2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
  [3] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
  [4] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
  [6] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
  [8] (::typeof(∂(encode)))(Δ::Vector{Union{Nothing, CUDA.CuArray{Float32, 4}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:39 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [12] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [13] main()
    @ Main ~/projects/Segmentation.jl/src/mwe.jl:38
 [14] top-level scope
    @ ~/projects/Segmentation.jl/src/mwe.jl:45
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:45
DhairyaLGandhi commented 3 years ago

Likely due to FillArrays.Ones, which may come from the push! adjoint.

I would get rid of the loops used to construct and run the models and see how that performs.

As an aside - I'd love an EfficientNet implementation in Metalhead as a PR ;)
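
One way to check that diagnosis in isolation (a sketch, assuming a working GPU and the same Flux/Zygote/CUDA versions as in the trace above): take the pullback of a single GPU BatchNorm and seed it once with a dense CuArray and once with a lazy FillArrays.Ones of the same shape.

using CUDA, FillArrays, Flux, Zygote

bn = BatchNorm(3) |> gpu |> trainmode!
x = CUDA.randn(Float32, 10, 10, 3, 1)

y, back = Zygote.pullback(bn, x)
back(CUDA.ones(Float32, size(y)...))   # dense CuArray cotangent: dispatches to ∇batchnorm fine
back(Ones(Float32, size(y)...))        # lazy cotangent: should reproduce the MethodError above

If the second call errors while the first one doesn't, the problem is purely that the CUDNN-backed ∇batchnorm only has CuArray methods and never sees a densified cotangent.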

pxl-th commented 3 years ago

Indeed, it comes from the push!. However, if you replace the identity activation with any other activation function (e.g. relu), the error disappears. But in MBConv the last BatchNorm has no activation.

Here's an even smaller MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    encoder = Chain(BatchNorm(3, identity), BatchNorm(3, identity)) |> device |> trainmode!
    θ = params(encoder)
    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()

Also, I'm not sure how you would get rid of the loops without unrolling them manually and losing generality. I want to be able to pass different encoders, which can have different feature-extraction depths, and having separate encoding and decoding stages makes this easier (similar to how it is done in the segmentation_models Python package).

Maybe, on the GPU, we should "materialize" a FillArrays.Ones into a CuArray if we get one? Especially since this already works fine on the CPU, it would make sense to have the same support on the GPU.
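
For what it's worth, one possible shape of that workaround (purely a hypothetical sketch, not code that exists in NNlibCUDA; the helper name is made up) would be an adapter that densifies lazy FillArrays cotangents before forwarding to the existing CuArray-only ∇batchnorm methods:

using CUDA, FillArrays, NNlibCUDA

# Hypothetical adapter (made-up name): collect the lazy fill into a dense CuArray,
# then call the CUDNN-backed ∇batchnorm, which only has CuArray methods.
function densify_∇batchnorm(g, b, x, dy::FillArrays.AbstractFill, running_mean, running_var, momentum; kwargs...)
    NNlibCUDA.∇batchnorm(g, b, x, CuArray(dy), running_mean, running_var, momentum; kwargs...)
end

The alternative would be to densify on the Flux/Zygote side, before the cotangent ever reaches ∇batchnorm.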

pxl-th commented 3 years ago

A similar thing happens if you replace BatchNorm with Conv, and again the error disappears if you specify a non-identity activation function.

MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    encoder = Chain(
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
    ) |> device |> trainmode!
    θ = params(encoder)

    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()

Error:

ERROR: LoadError: TaskFailedException

    nested task error: MethodError: no method matching
      gemm!(::Val{false}, ::Val{true}, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::CUDA.CuPtr{Float32}, ::Float32, ::Ptr{Float32})
    Closest candidates are:
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      ...
    Stacktrace:
     [1] macro expansion
       @ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:156 [inlined]
     [2] (::NNlib.var"#752#threadsfor_fun#391"{Float32, Array{Float32, 3}, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CUDA.CuArray{Float32, 5}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, Int64, Int64, Int64, UnitRange{Int64}})(onethread::Bool)
       @ NNlib ./threadingconstructs.jl:81
     [3] #invokelatest#2
       @ ./essentials.jl:708 [inlined]
     [4] invokelatest
       @ ./essentials.jl:706 [inlined]
     [5] macro expansion
       @ ./threadingconstructs.jl:86 [inlined]
     [6] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; col::Array{Float32, 3}, alpha::Float32, beta::Float32)
       @ NNlib ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:148
     [7] ∇conv_data_im2col!
       @ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:127 [inlined]
     [8] (::NNlib.var"#162#166"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, CUDA.CuArray{Float32, 5}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ./threadingconstructs.jl:169
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:369
  [2] macro expansion
    @ ./task.jl:388 [inlined]
  [3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:228
  [4] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:217
  [5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::CUDA.CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 3, 3, 1, (1, 1), (1, 1, 1, 1), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151
  [6] ∇conv_data!
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151 [inlined]
  [7] #∇conv_data#89
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:104 [inlined]
  [8] ∇conv_data
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:101 [inlined]
  [9] #204
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:313 [inlined]
 [10] unthunk
    @ ~/.julia/packages/ChainRulesCore/BYuIz/src/differentials/thunks.jl:192 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:55 [inlined]
 [12] map
    @ ./tuple.jl:215 [inlined]
 [13] map
    @ ./tuple.jl:216 [inlined]
 [14] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:56 [inlined]
 [15] ZBack
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:91 [inlined]
 [16] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/conv.jl:165 [inlined]
 [17] (::typeof(∂(λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
 [19] (::typeof(∂(encode)))(Δ::Vector{CUDA.CuArray{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:21 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#90#91"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [23] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [24] main()
    @ Main ~/projects/Segmentation.jl/src/mwe.jl:20
 [25] top-level scope
    @ ~/projects/Segmentation.jl/src/mwe.jl:26
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:26
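
Reading the Conv trace, this looks like the same root cause surfacing one level lower (an interpretation, not verified against the NNlib source): the lazy FillArrays.Ones cotangent is not a CuArray, so dispatch falls back to NNlib's CPU im2col kernels while the Conv weights are still on the GPU, hence the gemm! call mixing Ptr{Float32} and CUDA.CuPtr{Float32}. A rough sketch of that mismatch in isolation (hypothetical reproduction, not taken from the package's tests):

using CUDA, NNlib

x = randn(Float32, 10, 10, 3, 1)      # only used for its size when building cdims
w = CUDA.randn(Float32, 3, 3, 3, 3)   # Conv weights on the GPU, as in the MWE
dy = ones(Float32, 10, 10, 3, 1)      # stands in for the (CPU-resident) FillArrays.Ones cotangent

cdims = DenseConvDims(x, w; padding=1)
∇conv_data(dy, w, cdims)              # CPU dy + GPU w should land in the CPU im2col path → gemm! MethodError
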
ToucheSir commented 3 years ago

This doesn't resolve the underlying issue, but you shouldn't have to use any mutation or explicit loops to write an EfficientNet-style architecture. Here's an equivalent version that works:

using Flux

encode(encoder, x) = map(f -> f(x), encoder)

function decode(decoder, features)
    nskips = length(features) - 1
    skip_blocks, rest_blocks = decoder[1:nskips], decoder[nskips+1:end]
    xs = foldl(zip(skip_blocks, features[nskips:-1:1]); init=features[end]) do acc, (f, x)
        f(cat(acc, x; dims=ndims(x) - 1)) # ndims(x) - 1 == 3 here, but is more general 
    end
    return rest_blocks(xs)
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    # encoders are not chained (they are run in parallel), so don't make them a Chain
    encoder = (BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)

    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()
pxl-th commented 3 years ago

Yes, I've changed the feature-extraction part (the encoder) of the model to use map instead of loops, and I can now take gradients. Although having support for push! in this case would've been nice as well.
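
One note on the map-based encoder: it applies every block to the same input, i.e. the blocks run in parallel. If the blocks need to be chained (each feeding the next) while still collecting every intermediate output, as in the original loop, Flux.activations avoids both the loop and the push! — a sketch, assuming Flux.activations is available in this Flux version and the encoder is a Chain:

using Flux

encoder = Chain(Conv((3, 3), 3 => 8; pad=SamePad()), BatchNorm(8), Conv((3, 3), 8 => 16; pad=SamePad()))
x = randn(Float32, 10, 10, 3, 1)

# A Tuple holding each block's output, computed by feeding each block's output into the next, without mutation.
features = Flux.activations(encoder, x)

Whether that sidesteps the FillArrays.Ones issue on the GPU would still need checking, since the error depends on the cotangent type rather than on how the features are collected.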

Here's a fun animation of the training dynamics on a selected image from a small dataset, if anyone is curious :) (animation: output)