FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Zygote error: UndefVarError: S not defined #1578

Closed: jonathan-laurent closed this issue 2 years ago

jonathan-laurent commented 3 years ago

I am trying to update AlphaZero.jl so that it works with Flux v0.12 but I am stuck on the following Zygote error:

ERROR: UndefVarError: S not defined

The bug happens on both Flux@0.12.1 and Flux#master.

Replication instructions

To replicate, you can run the following using Julia 1.6:

git clone --branch flux-0.12 https://github.com/jonathan-laurent/AlphaZero.jl.git
cd AlphaZero.jl
julia --project -e 'import Pkg; Pkg.instantiate()'
julia --project -e 'using AlphaZero; Scripts.test_grad_updates("connect-four")'

Full stacktrace

ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{CUDA.CuArray{Float32, 4}}, ::Type{CUDA.CuArray{Float32, 2}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{CUDA.CuArray{Float32, 4}}, ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{CUDA.CuArray{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}}, ::Tuple{Int64, Base.Generator{Vector{CUDA.CuArray{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, AlphaZero.FluxLib.var"#27#28"}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [17] _pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [18] _pullback(ctx::Zygote.Context, f::typeof(AlphaZero.Network.regularized_params), args::ResNet)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [19] _pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [20] _pullback(::Zygote.Context, ::typeof(AlphaZero.losses), ::ResNet, ::LearningParams, ::Float32, ::Float32, ::Tuple{CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [21] _pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [22] _pullback(::Zygote.Context, ::AlphaZero.var"#L#110"{AlphaZero.Trainer}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 4}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 2}, ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [23] adjoint
    @ ~/.julia/packages/Zygote/CgsVi/src/lib/lib.jl:188 [inlined]
 [24] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [25] _pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [26] _pullback(::Zygote.Context, ::AlphaZero.FluxLib.var"#1#2"{AlphaZero.var"#L#110"{AlphaZero.Trainer}, Tuple{CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 2}}})
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [27] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:247
 [28] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:72
 [29] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [30] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [31] macro expansion
    @ ./timing.jl:356 [inlined]
 [32] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [33] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [34] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [35] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [36] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
 [37] top-level scope
    @ REPL[2]:1

DhairyaLGandhi commented 3 years ago

Interesting, thanks for pointing this out. A next step might be to check older Zygote releases to see where this was introduced. I suspected the push-related changes caused it, but I didn't find any stray S in there.

darsnack commented 3 years ago

The source of the issue is this line.

DhairyaLGandhi commented 3 years ago

julia> d = Chain(Dense(3,3), Dense(3,3))
Chain(Dense(3, 3), Dense(3, 3))

julia> x = rand(Float32, 3,4);

julia> gs = gradient(Flux.params(d)) do
         ds = Flux.modules(d)
         sum(l(x) for l in ds) |> sum
       end
Grads(...)

darsnack commented 3 years ago

Flux.modules has @nograd defined. I think the generator expression might be the problem.

DhairyaLGandhi commented 3 years ago

There's a generator in the above expression as well.

darsnack commented 3 years ago

julia> regularized_params_(l::Flux.Dense) = [l.W]
regularized_params_ (generic function with 1 method)

julia> regularized_params_(l) = []
regularized_params_ (generic function with 2 methods)

julia> regularized_params_(l::Flux.Conv) = [l.weight]
regularized_params_ (generic function with 3 methods)

julia> d = Chain(Conv((3, 3), 1 => 2), Dense(3,3), Dense(3,3))
Chain(Conv((3, 3), 1=>2), Dense(3, 3), Dense(3, 3))

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] _pullback
    @ ./REPL[474]:2 [inlined]
 [18] _pullback(::Zygote.Context, ::var"#39#41")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
 [20] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [21] top-level scope
    @ REPL[474]:1

darsnack commented 3 years ago

It's more subtle, too. You need a Conv + Dense; with only Dense layers the error is different:

julia> d = Chain(Dense(3, 3), Dense(3, 3))
Chain(Dense(3, 3), Dense(3, 3))

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
  [6] #s2993#1177
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ./REPL[457]:1 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#44)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./generator.jl:47 [inlined]
 [19] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./iterators.jl:1093 [inlined]
 [21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./array.jl:761 [inlined]
 [23] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./array.jl:743 [inlined]
 [25] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./array.jl:652 [inlined]
 [27] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./array.jl:602 [inlined]
 [29] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./REPL[476]:2 [inlined]
 [31] (::typeof(∂(#43)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#69#70"{Params, typeof(∂(#43)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [33] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [34] top-level scope
    @ REPL[476]:1

If you look at the stack trace, I think you'll see a call to typejoin, because the Conv and Dense weights are not the same type.
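
To make the typejoin step concrete, here is a small sketch that uses only Base Julia (no Flux or Zygote). It assumes, as the stack trace suggests, that the Conv weight is an Array{Float32, 4} and the Dense weight is a Matrix{Float32}; the variable names are illustrative only. It shows what typejoin computes for these two types, not why Zygote fails to differentiate through it.

```julia
# Types taken from the stack trace: a Conv weight and a Dense weight.
conv_weight_type = Array{Float32, 4}
dense_weight_type = Array{Float32, 2}   # same type as Matrix{Float32}

# typejoin returns the closest common supertype of its arguments.
joined = typejoin(conv_weight_type, dense_weight_type)

# The element type is shared but the number of dimensions is not,
# so the result leaves the dimension parameter free.
println(joined)
println(joined == Array{Float32})

# With identical types there is nothing to widen.
println(typejoin(dense_weight_type, dense_weight_type) == Matrix{Float32})
```

Both weight types are subtypes of the joined type, which is why a single vector can hold them once its element type has been widened.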

Even smaller reproducer without Flux.modules:

julia> gs = gradient(Flux.params(d)) do
           ws = [w for l in d for w in regularized_params_(l)]
           sum(sum(w) for w in ws)
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./promotion.jl:149 [inlined]
  [7] _pullback
    @ ./array.jl:748 [inlined]
  [8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./array.jl:767 [inlined]
 [10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./array.jl:743 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] _pullback
    @ ./array.jl:652 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [15] _pullback
    @ ./array.jl:602 [inlined]
 [16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [17] _pullback
    @ ./REPL[484]:2 [inlined]
 [18] _pullback(::Zygote.Context, ::var"#63#65")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [19] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
 [20] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [21] top-level scope
    @ REPL[484]:1

darsnack commented 3 years ago

I found that I needed different types for the elements in the generator, and I needed nested generators (to force the call to collect, which in turn triggers the typejoin). For example, the following gets a different error:

julia> gs = gradient(Flux.params(d)) do
           sum(sum(w) for m in d for w in regularized_params_(m))
       end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
  [6] #s2993#1177
    @ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ./REPL[457]:1 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#60)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./reduce.jl:93 [inlined]
 [19] (::typeof(∂(Base.MappingRF{var"#60#62", Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}}(var"#60#62"(), Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}(Base.BottomRF{typeof(Base.add_sum)}(Base.add_sum))))))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./reduce.jl:62 [inlined]
 [21] (::typeof(∂(_foldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./reduce.jl:48 [inlined]
 [23] (::typeof(∂(foldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./reduce.jl:44 [inlined]
 [25] (::typeof(∂(mapfoldl_impl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] Pullback (repeats 2 times)
    @ ./reduce.jl:160 [inlined]
 [27] (::typeof(∂(mapfoldl)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./reduce.jl:287 [inlined]
 [29] (::typeof(∂(#mapreduce#218)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./reduce.jl:287 [inlined]
 [31] (::typeof(∂(mapreduce)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [32] Pullback
    @ ./reduce.jl:501 [inlined]
 [33] (::typeof(∂(#sum#221)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [34] Pullback
    @ ./reduce.jl:501 [inlined]
 [35] (::typeof(∂(sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [36] Pullback
    @ ./reduce.jl:528 [inlined]
 [37] (::typeof(∂(#sum#222)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [38] Pullback
    @ ./reduce.jl:528 [inlined]
 [39] (::typeof(∂(sum)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [40] Pullback
    @ ./REPL[483]:2 [inlined]
 [41] (::typeof(∂(#59)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#69#70"{Params, typeof(∂(#59)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [43] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
 [44] top-level scope
    @ REPL[483]:1
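
The role of collect described above can be sketched without Flux or Zygote. The snippet below is an illustration under assumptions: per_layer_weights is a hypothetical stand-in for what regularized_params_ returns per layer, with one 4-dimensional array and one matrix. It only shows that collecting a flattened generator with mixed element types ends up with a widened element type.

```julia
# Hypothetical stand-ins for per-layer weight lists (no Flux needed):
# one "Conv-like" layer with a 4-d weight, one "Dense-like" layer with a matrix.
per_layer_weights = Any[[zeros(Float32, 3, 3, 1, 2)], [zeros(Float32, 3, 3)]]

# A generator with two `for` clauses is a flattened generator.
flat = (w for ws in per_layer_weights for w in ws)
println(flat isa Base.Iterators.Flatten)

# collect starts from the type of the first element and has to widen
# the element type when the second element does not fit.
collected = collect(flat)
println(eltype(collected))
```

Summing directly over the generator, as in the example above, never builds this vector, which is consistent with that variant hitting a different error.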
darsnack commented 3 years ago

Okay, finally got an MWE:

julia> gs = gradient(params(d)) do
           x = typejoin(Array{Float32, 4}, Array{Float32, 2})
           return 1
       end
ERROR: UndefVarError: S not defined
Stacktrace:
  [1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
  [3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [4] _pullback
    @ ./promotion.jl:87 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] _pullback
    @ ./REPL[486]:2 [inlined]
  [7] _pullback(::Zygote.Context, ::var"#67#68")
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [8] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
  [9] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
 [10] top-level scope
    @ REPL[486]:1

DhairyaLGandhi commented 3 years ago

Just opened https://github.com/FluxML/Zygote.jl/pull/947

Looks like we came to the same conclusion lol.

@darsnack added you as a co-author (hope you don't mind)

jonathan-laurent commented 3 years ago

I updated the Manifest to use https://github.com/FluxML/Zygote.jl/pull/947, but now I am seeing a different error. (I updated the flux-0.12 branch, so you can still use the replication instructions above.)

ERROR: Compiling Tuple{Base.var"##depwarn#868", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:315
  [5] _lookup_grad(T::Type)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/emit.jl:101
  [6] #s2996#1179
    @ ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:34 [inlined]
  [7] var"#s2996#1179"(T::Any, j::Any, Δ::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:571
  [9] Pullback
    @ ./deprecated.jl:80 [inlined]
 [10] (::typeof(∂(depwarn)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
 [12] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [14] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:113 [inlined]
 [15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./none:0 [inlined]
 [17] (::typeof(∂(#5)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [18] Pullback
    @ ./generator.jl:47 [inlined]
 [19] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [20] Pullback
    @ ./iterators.jl:1097 [inlined]
 [21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [22] Pullback
    @ ./array.jl:770 [inlined]
 [23] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [24] Pullback
    @ ./array.jl:768 [inlined]
 [25] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [26] Pullback
    @ ./array.jl:743 [inlined]
 [27] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [28] Pullback
    @ ./array.jl:652 [inlined]
 [29] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [30] Pullback
    @ ./array.jl:602 [inlined]
 [31] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [32] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [33] (::typeof(∂(regularized_params)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [34] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [35] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [36] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [37] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [38] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/lib/lib.jl:194
 [39] #1686#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [40] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [41] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface.jl:252
 [43] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
 [44] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [45] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [46] macro expansion
    @ ./timing.jl:356 [inlined]
 [47] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [48] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [49] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [50] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [51] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57

DhairyaLGandhi commented 3 years ago

I addressed that in https://github.com/JuliaDiff/ChainRules.jl/pull/398.

DhairyaLGandhi commented 3 years ago

Should be closed in https://github.com/JuliaDiff/ChainRules.jl/pull/398

jonathan-laurent commented 3 years ago

Now it works if I define this:

function Network.regularized_params(net::FluxNetwork)
  return (w for l in Flux.modules(net) for w in regularized_params_(l))
end

Just so you know, I still have an error if I use this definition instead (returning an array instead of a generator):

function Network.regularized_params(net::FluxNetwork)
  return [w for l in Flux.modules(net) for w in regularized_params_(l)]
end

I am not sure whether or not this should be considered a bug. If not, how should I modify the second example to make it work? Would I need to use a @nograd macro somewhere?
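
For reference, the difference between the two definitions can be illustrated in plain Julia. This is a sketch under assumptions, not a fix: per_layer_weights and l2 are hypothetical stand-ins, and nothing here involves Zygote, so it says nothing about differentiability by itself. It only shows that the generator form stays lazy while the comprehension materializes a vector that collect has to grow, and that a penalty computed over either form gives the same value.

```julia
# Hypothetical stand-ins for layer weights; plain arrays, no Flux or Zygote.
per_layer_weights = Any[[ones(Float32, 2, 2, 1, 1)], [ones(Float32, 2, 2)]]

# Generator form: lazy, no array is allocated or pushed into.
lazy_weights = (w for ws in per_layer_weights for w in ws)

# Array form: collect builds a Vector and has to grow and widen it.
eager_weights = [w for ws in per_layer_weights for w in ws]

# A hypothetical L2-style penalty that can consume either form.
l2(weights) = sum(sum(abs2, w) for w in weights)

println(lazy_weights isa AbstractArray)   # false: still a lazy iterator
println(eager_weights isa AbstractArray)  # true: a materialized Vector
println(l2(lazy_weights) == l2(eager_weights))
```

If the caller only iterates over the result, returning the generator avoids the array-growing path that appears in the stack trace.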

Error with the array definition:

ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#405#406")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/array.jl:61
  [3] (::Zygote.var"#2266#back#407"{Zygote.var"#405#406"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./array.jl:977 [inlined]
  [5] (::typeof(∂(append!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./array.jl:753 [inlined]
  [7] (::typeof(∂(push_widen)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./array.jl:767 [inlined]
  [9] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./array.jl:743 [inlined]
 [11] (::typeof(∂(grow_to!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [12] Pullback
    @ ./array.jl:652 [inlined]
 [13] (::typeof(∂(_collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./array.jl:602 [inlined]
 [15] (::typeof(∂(collect)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
 [17] (::typeof(∂(regularized_params)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
 [19] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
 [21] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/lib.jl:194
 [23] #1686#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [24] Pullback
    @ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
 [25] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
 [27] lossgrads(f::Function, args::Zygote.Params)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
 [28] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
    @ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
 [29] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
    @ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
 [30] macro expansion
    @ ./timing.jl:356 [inlined]
 [31] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
    @ AlphaZero ~/AlphaZero.jl/src/training.jl:224
 [32] test_grad_updates(exp::Experiment; num_games::Int64)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
 [33] test_grad_updates
    @ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
 [34] #test_grad_updates#21
    @ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
 [35] test_grad_updates(s::String)
    @ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
 [36] top-level scope
    @ REPL[3]:1

DhairyaLGandhi commented 3 years ago

Interesting. Is this with the ChainRules update?

jonathan-laurent commented 3 years ago

Yes, this is with the ChainRules update.

DhairyaLGandhi commented 3 years ago

So this is in fact doing array mutation internally, so I'm comfortable saying that Zygote is showing the expected behavior here. However, nested generators are something that we had working before. For now I'd return a literal generator and look into the code generation separately.
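
To illustrate where the mutation comes from: a flattened (nested) generator has no known length, so Base cannot preallocate the result of collecting it and instead grows the array element by element, which matches the grow_to! and push_widen frames in the stack trace above. The sketch below uses only Base. The nested vectors and the helper names lazy_params, eager_params and penalty are made up for illustration and are not AlphaZero.jl code.

```julia
# Nested vectors stand in for layers and their weight arrays (made-up data).
layers = [[[1.0, 2.0], [3.0]], [[4.0]]]

# Hypothetical helpers mirroring the two definitions discussed in this thread.
lazy_params(ls) = (w for l in ls for w in l)    # returns a generator
eager_params(ls) = [w for l in ls for w in l]   # collects into an array

# The flattened generator reports an unknown size, so collecting it
# (which is what the comprehension does) has to grow the result step by step.
size_trait = Base.IteratorSize(lazy_params(layers))
println(size_trait)

# A toy L2 penalty gives the same value whichever form it consumes;
# only the lazy form avoids building the intermediate array.
penalty(ws) = sum(sum(abs2, w) for w in ws)
lazy_penalty = penalty(lazy_params(layers))
eager_penalty = penalty(eager_params(layers))
println((lazy_penalty, eager_penalty))
```

Consuming the generator directly, as suggested above, sidesteps the intermediate array. Whether Zygote then differentiates it cleanly depends on the Zygote version in use.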

jonathan-laurent commented 3 years ago

Interesting update:

This issue is about making sure that code like this works with Flux:

gs = gradient(Flux.params(d)) do
    ds = Flux.modules(d)
    sum(l(x) for l in ds) |> sum
end

However, the code above appears to be very slow. Indeed, I achieved a roughly 3x speedup (and a nearly ninefold reduction in the number of GPU allocations) in the backprop phase of AlphaZero.jl by essentially replacing the code above with:

ds = Flux.modules(d)
gs = gradient(Flux.params(d)) do
    sum(l(x) for l in ds) |> sum
end

For more details, see the exact commit: https://github.com/jonathan-laurent/AlphaZero.jl/commit/b8bac93c9048e84ebd1aee6e51fad5b58dae357f

Here is the output of CUDA.@time before and after I made the change in AlphaZero.jl:

BEFORE:
12.463815 seconds (44.18 M CPU allocations: 1.117 GiB, 4.67% gc time) (401.32 k GPU allocations: 314.618 GiB, 18.44% gc time of which 34.41% spent allocating)

AFTER:
4.198313 seconds (11.60 M CPU allocations: 451.503 MiB, 7.84% gc time) (45.81 k GPU allocations: 314.617 GiB, 8.46% gc time of which 38.23% spent allocating)

Although it is understandable for the second version to be faster, I definitely did not expect such a gap. Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise.
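
For reference, the ratios implied by the two CUDA.@time lines above can be checked with a few lines of arithmetic. This is only a sketch: the numbers are copied from the BEFORE and AFTER output, and the variable names are made up for illustration.

```julia
# Figures copied from the CUDA.@time output above.
before = (seconds = 12.463815, cpu_allocs = 44.18e6, gpu_allocs = 401.32e3)
after  = (seconds = 4.198313,  cpu_allocs = 11.60e6, gpu_allocs = 45.81e3)

speedup         = before.seconds / after.seconds
cpu_alloc_ratio = before.cpu_allocs / after.cpu_allocs
gpu_alloc_ratio = before.gpu_allocs / after.gpu_allocs

println("wall-clock speedup:   ", round(speedup, digits = 2))          # 2.97
println("CPU allocation ratio: ", round(cpu_alloc_ratio, digits = 2))  # 3.81
println("GPU allocation ratio: ", round(gpu_alloc_ratio, digits = 2))  # 8.76
```

Note that the total GPU memory allocated is essentially unchanged (314.618 GiB versus 314.617 GiB), so the difference is in the number of allocations rather than their combined size.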

darsnack commented 3 years ago

> Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise.

How recent? Flux.modules is new (as of v0.12) and uses a very different implementation (based on Functors.jl) than the for-loop it replaces in AlphaZero.jl (but see the edit below). So I am not surprised that they don't perform the same, but it's good to know that it is this slow. We definitely want to make the first code example performant, because that's how I'd expect Flux.modules to be used most of the time.

Unless by recent you mean within the last couple of weeks (i.e. comparing Flux.modules before and after some updates), in which case Zygote would be the first place to look.

Maybe a good cross-check would be to replace Flux.modules with some other generator?


EDIT:

Actually they are very similar.

jonathan-laurent commented 3 years ago

Before Flux v0.12, I was relying on my own implementation of modules.

See this commit: https://github.com/jonathan-laurent/AlphaZero.jl/commit/7bbb2cb13cc09a4cd2bed905acafc1e45097a585#diff-837cea9edf0b0d7507b695b63c300d02574ec50c6b3227ee6fe4701544b3e97a

ToucheSir commented 3 years ago

Could it be that fcollect uses a vector as a cache and that's slower in AD? Not sure what else could be a factor.

DhairyaLGandhi commented 3 years ago

It has already been marked @nograd.

ToucheSir commented 3 years ago

Ah right, disregard that then.

I played around with the examples above on CPU and GPU. Not sure why using modules inside the gradient context has such a crushing performance impact on CPU-side allocations given that modules is @nograd.

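using Flux
using BenchmarkTools  # provides @btime, used below

# Five stacked Dense layers and a small random batch.
d = Chain((Dense(32, 32) for i in 1:5)...)
x = rand(Float32, 32, 8)

# f1: Flux.modules is called inside the gradient closure.
f1(m, x) = gradient(Flux.params(m)) do
    ds = Flux.modules(m)
    sum(l(x) for l in ds) |> sum
end

# f2: Flux.modules is hoisted out of the gradient closure.
function f2(m, x)
    ds = Flux.modules(m)
    gradient(Flux.params(m)) do
        sum(l(x) for l in ds) |> sum
    end
end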

julia> @timev f1(d, x)
 24.901731 seconds (63.67 M allocations: 3.610 GiB, 4.22% gc time)
elapsed time (ns): 24901730841
gc time (ns):      1051987239
bytes allocated:   3875757116
pool allocs:       63656612
non-pool GC allocs:15360
malloc() calls:    88
realloc() calls:   15
GC pauses:         53
full collections:  2

julia> @timev f2(d, x)
  1.086023 seconds (2.15 M allocations: 119.429 MiB, 2.75% gc time, 98.94% compilation time)
elapsed time (ns): 1086023325
gc time (ns):      29910632
bytes allocated:   125229937
pool allocs:       2152312
non-pool GC allocs:161
GC pauses:         3

julia> @btime f2($d, $x)
  415.283 μs (2140 allocations: 210.67 KiB)

julia> @btime f1($d, $x)
  414.055 μs (2146 allocations: 214.48 KiB)

mcabbott commented 2 years ago

FWIW, running the above example:

julia> @timev f1(d, x)
 12.528920 seconds (61.88 M allocations: 3.262 GiB, 5.23% gc time, 99.97% compilation time)
elapsed time (ns): 12528919792
gc time (ns):      655876707
bytes allocated:   3502653566
pool allocs:       61852547
non-pool GC allocs:23328
malloc() calls:    103
realloc() calls:   20
GC pauses:         42
full collections:  1
Grads(...)

julia> @timev f2(d, x)
  0.456313 seconds (2.03 M allocations: 103.419 MiB, 5.52% gc time, 99.67% compilation time)
elapsed time (ns): 456313042
gc time (ns):      25197167
bytes allocated:   108442549
pool allocs:       2025083
non-pool GC allocs:271
GC pauses:         2
Grads(...)

julia> @btime f2($d, $x)
  min 171.500 μs, mean 209.663 μs (1631 allocations, 309.03 KiB. GC mean 11.82%)
Grads(...)

julia> @btime f1($d, $x)
  min 173.167 μs, mean 209.863 μs (1636 allocations, 312.83 KiB. GC mean 11.27%)
Grads(...)

And after re-starting, running them in the opposite order:

julia> @timev f2(d, x)
 12.708308 seconds (61.94 M allocations: 3.266 GiB, 5.74% gc time, 99.98% compilation time)
elapsed time (ns): 12708308041
gc time (ns):      729664123
bytes allocated:   3506520558
pool allocs:       61916466
non-pool GC allocs:23354
malloc() calls:    104
realloc() calls:   20
GC pauses:         39
full collections:  2
Grads(...)

julia> @timev f1(d, x)
  0.447610 seconds (1.98 M allocations: 100.910 MiB, 2.59% gc time, 99.69% compilation time)
elapsed time (ns): 447609750
gc time (ns):      11576334
bytes allocated:   105811871
pool allocs:       1980090
non-pool GC allocs:250
GC pauses:         1
Grads(...)

(@v1.8) pkg> st Flux Zygote
Status `~/.julia/environments/v1.8/Project.toml`
  [587475ba] Flux v0.12.8
  [e88e6eb3] Zygote v0.6.32

So there's 12s of startup time on whichever is first, but no obvious effect of where Flux.modules(m) is called.

ToucheSir commented 2 years ago

With the benefit of hindsight, the first run is basically measuring any AD compilation overhead, while the second can benefit from caching. It also looks like the original issue is resolved, so worth opening a new one if anything persists.
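
The same first-call effect can be reproduced without Flux or Zygote. The sketch below uses only Base, and the function work is a made-up stand-in for any code that has not yet been compiled in the session. The exact timings vary by machine, so none are shown.

```julia
# Hypothetical workload; any function not yet compiled in this session would do.
work(xs) = sum(x -> x^2 + 1, xs)

xs = rand(Float32, 1_000)

# The first call pays for JIT compilation of `work` for this argument type.
t_first = @elapsed work(xs)

# The second call reuses the compiled code, so it measures run time only.
t_second = @elapsed work(xs)

r = work(xs)
println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

This is why a tool such as @btime, which runs the function many times, reports nearly identical numbers for f1 and f2 above, while a single @timev on a fresh session mostly measures compilation.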