jonathan-laurent closed this 2 years ago.
Interesting, thanks for pointing this out. The next step might be to check older Zygote releases to see where this was introduced. I suspected the push-related changes caused it, but I didn't find any stray uses of S in there.
julia> d = Chain(Dense(3,3), Dense(3,3))
Chain(Dense(3, 3), Dense(3, 3))
julia> x = rand(Float32, 3,4);
julia> gs = gradient(Flux.params(d)) do
ds = Flux.modules(d)
sum(l(x) for l in ds) |> sum
end
Grads(...)
Flux.modules has @nograd defined. I think the generator expression might be the problem.
There's a generator in the above expression as well.
julia> regularized_params_(l::Flux.Dense) = [l.W]
regularized_params_ (generic function with 1 method)
julia> regularized_params_(l) = []
regularized_params_ (generic function with 2 methods)
julia> regularized_params_(l::Flux.Conv) = [l.weight]
regularized_params_ (generic function with 3 methods)
julia> d = Chain(Conv((3, 3), 1 => 2), Dense(3,3), Dense(3,3))
Chain(Conv((3, 3), 1=>2), Dense(3, 3), Dense(3, 3))
julia> gs = gradient(Flux.params(d)) do
ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
sum(sum(w) for w in ws)
end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./promotion.jl:149 [inlined]
[7] _pullback
@ ./array.jl:748 [inlined]
[8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[9] _pullback
@ ./array.jl:767 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] _pullback
@ ./array.jl:743 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] _pullback
@ ./array.jl:652 [inlined]
[14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[15] _pullback
@ ./array.jl:602 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Vector{Any}, var"#40#42"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[17] _pullback
@ ./REPL[474]:2 [inlined]
[18] _pullback(::Zygote.Context, ::var"#39#41")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[19] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[20] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[21] top-level scope
@ REPL[474]:1
It's more subtle, too: you need a Conv + Dense combination. With only Dense layers you get a different error:
julia> d = Chain(Dense(3, 3), Dense(3, 3))
Chain(Dense(3, 3), Dense(3, 3))
julia> gs = gradient(Flux.params(d)) do
ws = [w for l in Flux.modules(d) for w in regularized_params_(l)]
sum(sum(w) for w in ws)
end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
[5] _lookup_grad(T::Type)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
[6] #s2993#1177
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
[7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] Pullback
@ ./deprecated.jl:80 [inlined]
[10] (::typeof(∂(depwarn)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
[12] (::typeof(∂(getproperty)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[14] Pullback
@ ./REPL[457]:1 [inlined]
[15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[16] Pullback
@ ./none:0 [inlined]
[17] (::typeof(∂(#44)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[18] Pullback
@ ./generator.jl:47 [inlined]
[19] (::typeof(∂(iterate)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[20] Pullback
@ ./iterators.jl:1093 [inlined]
[21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[22] Pullback
@ ./array.jl:761 [inlined]
[23] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[24] Pullback
@ ./array.jl:743 [inlined]
[25] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[26] Pullback
@ ./array.jl:652 [inlined]
[27] (::typeof(∂(_collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[28] Pullback
@ ./array.jl:602 [inlined]
[29] (::typeof(∂(collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[30] Pullback
@ ./REPL[476]:2 [inlined]
[31] (::typeof(∂(#43)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[32] (::Zygote.var"#69#70"{Params, typeof(∂(#43)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
[33] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
[34] top-level scope
@ REPL[476]:1
If you look at the stack trace, you'll see a call to typejoin, because the Conv and Dense weights are not of the same type (Array{Float32, 4} vs. Matrix{Float32}).
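A Base-only sketch of why typejoin shows up at all: collecting a nested comprehension whose first element is a 4-D array and whose later elements are matrices forces collect to widen the element type of its output. FakeConv and FakeDense below are hypothetical stand-ins for the Flux layers (not the real types), and regparams is a hypothetical analogue of regularized_params_; no Flux or Zygote is involved.

```julia
# Hypothetical stand-ins for Flux's Conv and Dense layers (not the real types).
struct FakeConv
    weight::Array{Float32, 4}
end
struct FakeDense
    weight::Matrix{Float32}
end

# Hypothetical analogue of regularized_params_ from the thread.
regparams(l::FakeConv) = [l.weight]
regparams(l::FakeDense) = [l.weight]

layers = Any[FakeConv(rand(Float32, 3, 3, 1, 2)),
             FakeDense(rand(Float32, 3, 3)),
             FakeDense(rand(Float32, 3, 3))]

# Nested comprehension: lowered to collect(Iterators.flatten(...)).
ws = [w for l in layers for w in regparams(l)]

# collect starts from a Vector{Array{Float32, 4}} (the type of the first
# element) and widens it when the first Matrix{Float32} arrives. The widened
# element type is the join of the two array types, computed via typejoin.
joined = typejoin(Array{Float32, 4}, Matrix{Float32})
println(joined)
println(eltype(ws))
```

Outside of AD this widening is harmless; the thread's problem is only that Zygote had to differentiate through the typejoin call.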
An even smaller reproducer, without Flux.modules:
julia> gs = gradient(Flux.params(d)) do
ws = [w for l in d for w in regularized_params_(l)]
sum(sum(w) for w in ws)
end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./promotion.jl:149 [inlined]
[7] _pullback
@ ./array.jl:748 [inlined]
[8] _pullback(::Zygote.Context, ::typeof(Base.push_widen), ::Vector{Array{Float32, 4}}, ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[9] _pullback
@ ./array.jl:767 [inlined]
[10] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Array{Float32, 4}}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Tuple{Int64, Base.Generator{Vector{Array{Float32, 4}}, typeof(identity)}, Int64})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] _pullback
@ ./array.jl:743 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(Base.grow_to!), ::Vector{Any}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] _pullback
@ ./array.jl:652 [inlined]
[14] _pullback(::Zygote.Context, ::typeof(Base._collect), ::UnitRange{Int64}, ::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}}, ::Base.EltypeUnknown, ::Base.SizeUnknown)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[15] _pullback
@ ./array.jl:602 [inlined]
[16] _pullback(ctx::Zygote.Context, f::typeof(collect), args::Base.Iterators.Flatten{Base.Generator{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, var"#64#66"}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[17] _pullback
@ ./REPL[484]:2 [inlined]
[18] _pullback(::Zygote.Context, ::var"#63#65")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[19] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[20] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[21] top-level scope
@ REPL[484]:1
I found that I needed both different element types in the generator and nested generators (to force the call to collect, which in turn triggers the typejoin). For example, summing the nested generator directly, without collecting it, gets a different error:
julia> gs = gradient(Flux.params(d)) do
sum(sum(w) for m in d for w in regularized_params_(m))
end
ERROR: Compiling Tuple{Base.var"##depwarn#864", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/reverse.jl:315
[5] _lookup_grad(T::Type)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/emit.jl:101
[6] #s2993#1177
@ ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:34 [inlined]
[7] var"#s2993#1177"(T::Any, j::Any, Δ::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] Pullback
@ ./deprecated.jl:80 [inlined]
[10] (::typeof(∂(depwarn)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
[12] (::typeof(∂(getproperty)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[14] Pullback
@ ./REPL[457]:1 [inlined]
[15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[16] Pullback
@ ./none:0 [inlined]
[17] (::typeof(∂(#60)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[18] Pullback
@ ./reduce.jl:93 [inlined]
[19] (::typeof(∂(Base.MappingRF{var"#60#62", Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}}(var"#60#62"(), Base.FlatteningRF{Base.BottomRF{typeof(Base.add_sum)}}(Base.BottomRF{typeof(Base.add_sum)}(Base.add_sum))))))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[20] Pullback
@ ./reduce.jl:62 [inlined]
[21] (::typeof(∂(_foldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[22] Pullback
@ ./reduce.jl:48 [inlined]
[23] (::typeof(∂(foldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[24] Pullback
@ ./reduce.jl:44 [inlined]
[25] (::typeof(∂(mapfoldl_impl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[26] Pullback (repeats 2 times)
@ ./reduce.jl:160 [inlined]
[27] (::typeof(∂(mapfoldl)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[28] Pullback
@ ./reduce.jl:287 [inlined]
[29] (::typeof(∂(#mapreduce#218)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[30] Pullback
@ ./reduce.jl:287 [inlined]
[31] (::typeof(∂(mapreduce)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[32] Pullback
@ ./reduce.jl:501 [inlined]
[33] (::typeof(∂(#sum#221)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[34] Pullback
@ ./reduce.jl:501 [inlined]
[35] (::typeof(∂(sum)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[36] Pullback
@ ./reduce.jl:528 [inlined]
[37] (::typeof(∂(#sum#222)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[38] Pullback
@ ./reduce.jl:528 [inlined]
[39] (::typeof(∂(sum)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[40] Pullback
@ ./REPL[483]:2 [inlined]
[41] (::typeof(∂(#59)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[42] (::Zygote.var"#69#70"{Params, typeof(∂(#59)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
[43] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
[44] top-level scope
@ REPL[483]:1
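The difference between the two forms above can be seen in plain Julia: the nested comprehension is collected from a flattened generator whose length and element type are both unknown up front, so collect must grow (and possibly widen) its output, whereas sum over the generator is folded through mapfoldl, as the second stack trace shows, without materializing an array of weights. The per-layer parameter lists here are hypothetical plain arrays standing in for layer weights.

```julia
# Hypothetical per-layer parameter lists: plain Float32 arrays, no Flux involved.
params_per_layer = Any[[ones(Float32, 2, 2, 1, 1)], [ones(Float32, 2, 2)], Matrix{Float32}[]]

# The iterator that `[w for ps in params_per_layer for w in ps]` collects.
flat = Iterators.flatten((w for w in ps) for ps in params_per_layer)

# Neither the length nor the element type is known ahead of time:
# the trait values are Base.SizeUnknown() and Base.EltypeUnknown(),
# matching the arguments of Base._collect in the first stack trace.
@show Base.IteratorSize(flat)
@show Base.IteratorEltype(flat)

ws = collect(flat)                          # materializes an array of weights
total_collected = sum(sum(w) for w in ws)

# Same value, but the generator is consumed by a fold; no array of weights is built.
total_lazy = sum(sum(w) for ps in params_per_layer for w in ps)
println((total_collected, total_lazy))
```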
Okay, finally got an MWE:
julia> gs = gradient(params(d)) do
x = typejoin(Array{Float32, 4}, Array{Float32, 2})
return 1
end
ERROR: UndefVarError: S not defined
Stacktrace:
[1] (typeof(∂(λ)))(x::Tuple{typeof(∂(getproperty))})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:19
[2] _pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[3] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::Type{Type{T}}, ::Val{:name})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[4] _pullback
@ ./promotion.jl:87 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(typejoin), ::Type{Array{Float32, 4}}, ::Type{Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] _pullback
@ ./REPL[486]:2 [inlined]
[7] _pullback(::Zygote.Context, ::var"#67#68")
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[8] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:247
[9] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:58
[10] top-level scope
@ REPL[486]:1
Just opened https://github.com/FluxML/Zygote.jl/pull/947
Looks like we came to the same conclusion lol.
@darsnack added you as a co-author (hope you don't mind)
I updated the Manifest to use https://github.com/FluxML/Zygote.jl/pull/947 but now I am seeing a different error. (I updated the flux-0.12 branch so you can still use the replication instructions above).
ERROR: Compiling Tuple{Base.var"##depwarn#868", Bool, typeof(Base.depwarn), String, Symbol}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/reverse.jl:315
[5] _lookup_grad(T::Type)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/emit.jl:101
[6] #s2996#1179
@ ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:34 [inlined]
[7] var"#s2996#1179"(T::Any, j::Any, Δ::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[9] Pullback
@ ./deprecated.jl:80 [inlined]
[10] (::typeof(∂(depwarn)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/Flux/qp1gc/src/deprecations.jl:13 [inlined]
[12] (::typeof(∂(getproperty)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
[14] Pullback
@ ~/AlphaZero.jl/src/networks/flux.jl:113 [inlined]
[15] (::typeof(∂(regularized_params_)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[16] Pullback
@ ./none:0 [inlined]
[17] (::typeof(∂(#5)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[18] Pullback
@ ./generator.jl:47 [inlined]
[19] (::typeof(∂(iterate)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[20] Pullback
@ ./iterators.jl:1097 [inlined]
[21] (::typeof(∂(iterate)))(Δ::Tuple{Nothing, Tuple{Nothing, Nothing, Nothing}})
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[22] Pullback
@ ./array.jl:770 [inlined]
[23] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[24] Pullback
@ ./array.jl:768 [inlined]
[25] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[26] Pullback
@ ./array.jl:743 [inlined]
[27] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[28] Pullback
@ ./array.jl:652 [inlined]
[29] (::typeof(∂(_collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[30] Pullback
@ ./array.jl:602 [inlined]
[31] (::typeof(∂(collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[32] Pullback
@ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
[33] (::typeof(∂(regularized_params)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[34] Pullback
@ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
[35] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[36] Pullback
@ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
[37] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[38] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/lib/lib.jl:194
[39] #1686#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[40] Pullback
@ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
[41] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface2.jl:0
[42] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/iSZne/src/compiler/interface.jl:252
[43] lossgrads(f::Function, args::Zygote.Params)
@ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
[44] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
@ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
[45] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
@ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
[46] macro expansion
@ ./timing.jl:356 [inlined]
[47] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
@ AlphaZero ~/AlphaZero.jl/src/training.jl:224
[48] test_grad_updates(exp::Experiment; num_games::Int64)
@ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
[49] test_grad_updates
@ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
[50] #test_grad_updates#21
@ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
[51] test_grad_updates(s::String)
@ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
I addressed that in https://github.com/JuliaDiff/ChainRules.jl/pull/398.
Should be closed in https://github.com/JuliaDiff/ChainRules.jl/pull/398
Now it works if I define this:
function Network.regularized_params(net::FluxNetwork)
return (w for l in Flux.modules(net) for w in regularized_params_(l))
end
Just so you know, I still have an error if I use this definition instead (returning an array instead of a generator):
function Network.regularized_params(net::FluxNetwork)
return [w for l in Flux.modules(net) for w in regularized_params_(l)]
end
I am not sure whether or not this should be considered a bug. If not, how should I modify the second example to make it work? Would I need to use a @nograd macro somewhere?
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#405#406")(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/array.jl:61
[3] (::Zygote.var"#2266#back#407"{Zygote.var"#405#406"})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ./array.jl:977 [inlined]
[5] (::typeof(∂(append!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[6] Pullback
@ ./array.jl:753 [inlined]
[7] (::typeof(∂(push_widen)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[8] Pullback
@ ./array.jl:767 [inlined]
[9] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[10] Pullback
@ ./array.jl:743 [inlined]
[11] (::typeof(∂(grow_to!)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[12] Pullback
@ ./array.jl:652 [inlined]
[13] (::typeof(∂(_collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[14] Pullback
@ ./array.jl:602 [inlined]
[15] (::typeof(∂(collect)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[16] Pullback
@ ~/AlphaZero.jl/src/networks/flux.jl:117 [inlined]
[17] (::typeof(∂(regularized_params)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[18] Pullback
@ ~/AlphaZero.jl/src/learning.jl:75 [inlined]
[19] (::typeof(∂(losses)))(Δ::Tuple{Float32, Nothing, Nothing, Nothing, Nothing})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[20] Pullback
@ ~/AlphaZero.jl/src/learning.jl:122 [inlined]
[21] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[22] (::Zygote.var"#178#179"{Tuple{NTuple{5, Nothing}}, typeof(∂(λ))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/lib/lib.jl:194
[23] #1686#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[24] Pullback
@ ~/AlphaZero.jl/src/networks/flux.jl:82 [inlined]
[25] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[26] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:252
[27] lossgrads(f::Function, args::Zygote.Params)
@ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:73
[28] train!(callback::AlphaZero.var"#109#111"{Vector{Float32}}, nn::ResNet, opt::Adam, loss::Function, data::Base.Iterators.Take{Base.Iterators.Stateful{Base.Iterators.Flatten{Base.Generator{Base.Iterators.Repeated{Nothing}, AlphaZero.Util.var"#12#13"{AlphaZero.var"#106#108"{ResNet}, Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}, Int64, Bool}}}, Tuple{NTuple{5, Any}, Tuple{Nothing, Base.Generator{Vector{Tuple{Matrix{Float32}, Array{Float32, 4}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, AlphaZero.Util.var"#9#11"{AlphaZero.var"#106#108"{ResNet}}}, Int64}}}}, n::Int64)
@ AlphaZero.FluxLib ~/AlphaZero.jl/src/networks/flux.jl:81
[29] batch_updates!(tr::AlphaZero.Trainer, n::Int64)
@ AlphaZero ~/AlphaZero.jl/src/learning.jl:125
[30] macro expansion
@ ./timing.jl:356 [inlined]
[31] learning_step!(env::Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}, handler::Session{Env{AlphaZero.Examples.ConnectFour.GameSpec, ResNet, NamedTuple{(:board, :curplayer), Tuple{StaticArrays.SMatrix{7, 6, UInt8, 42}, UInt8}}}})
@ AlphaZero ~/AlphaZero.jl/src/training.jl:224
[32] test_grad_updates(exp::Experiment; num_games::Int64)
@ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:17
[33] test_grad_updates
@ ~/AlphaZero.jl/src/scripts/test_grad_updates.jl:10 [inlined]
[34] #test_grad_updates#21
@ ~/AlphaZero.jl/src/scripts/scripts.jl:57 [inlined]
[35] test_grad_updates(s::String)
@ AlphaZero.Scripts ~/AlphaZero.jl/src/scripts/scripts.jl:57
[36] top-level scope
@ REPL[3]:1
Interesting, is this with the ChainRules update?
Yes, this is with the ChainRules update.
So this is in fact doing array mutation internally (append! inside push_widen, per the stack trace), so I'm comfortable saying that Zygote is showing the expected behavior. However, nested generators are something we had working before. For now I'd return a literal generator, and we can look into the code generation separately.
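A minimal Base-only sketch of that workaround: the hypothetical regularized_params below returns a lazy generator, mirroring the definition that works in the thread, and the caller reduces over it with sum, so no array is built with push!/append! along the way. FakeConv, FakeDense, and l2_penalty are hypothetical names, not Flux or AlphaZero.jl API, and the `init` keyword of sum assumes Julia 1.6 or later.

```julia
# Hypothetical stand-ins for Conv and Dense (not Flux types); both store `weight`.
struct FakeConv
    weight::Array{Float32, 4}
end
struct FakeDense
    weight::Matrix{Float32}
end

# Per-layer parameter lists, mirroring regularized_params_ in the thread.
regularized_params_(l::FakeConv) = [l.weight]
regularized_params_(l::FakeDense) = [l.weight]
regularized_params_(l) = []

# Returns a lazy generator; nothing is collected inside this function.
regularized_params(layers) = (w for l in layers for w in regularized_params_(l))

# Hypothetical L2 penalty: sum folds over the generator directly.
# `init` keeps the result defined when no layer has regularized weights.
l2_penalty(layers) = sum((sum(abs2, w) for w in regularized_params(layers)); init = 0f0)

layers = Any[FakeConv(fill(1f0, 3, 3, 1, 2)), FakeDense(fill(2f0, 3, 3)), :not_a_layer]
println(l2_penalty(layers))
```

The conv weights contribute 18 × 1² and the dense weights 9 × 2², so the penalty here is 54; layers without regularized weights contribute nothing.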
Interesting update:
This issue is about making sure that code like this works with Flux:
gs = gradient(Flux.params(d)) do
ds = Flux.modules(d)
sum(l(x) for l in ds) |> sum
end
However, the code above appears to be very slow. Indeed, I achieved a 3x speedup (and a roughly ninefold reduction in the number of GPU allocations) in the backprop phase of AlphaZero.jl by essentially replacing the code above with:
ds = Flux.modules(d)
gs = gradient(Flux.params(d)) do
sum(l(x) for l in ds) |> sum
end
For more details, see the exact commit: https://github.com/jonathan-laurent/AlphaZero.jl/commit/b8bac93c9048e84ebd1aee6e51fad5b58dae357f
Here is the output of CUDA.@time before and after I made the change in AlphaZero.jl:
BEFORE:
12.463815 seconds (44.18 M CPU allocations: 1.117 GiB, 4.67% gc time) (401.32 k GPU allocations: 314.618 GiB, 18.44% gc time of which 34.41% spent allocating)
AFTER:
4.198313 seconds (11.60 M CPU allocations: 451.503 MiB, 7.84% gc time) (45.81 k GPU allocations: 314.617 GiB, 8.46% gc time of which 38.23% spent allocating)
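As a quick check of the ratios implied by these two measurements (numbers copied from the CUDA.@time output above): wall time drops by about 2.97x, CPU allocations by about 3.81x, and GPU allocations by about 8.76x.

```julia
# Figures copied from the BEFORE/AFTER CUDA.@time output above.
before = (seconds = 12.463815, cpu_allocs = 44.18e6, gpu_allocs = 401.32e3)
after  = (seconds = 4.198313,  cpu_allocs = 11.60e6, gpu_allocs = 45.81e3)

# Improvement factor for one field (before / after).
ratio(field) = before[field] / after[field]

for field in (:seconds, :cpu_allocs, :gpu_allocs)
    println(field, ": ", round(ratio(field); digits = 2), "x")
end
```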
Although it is understandable for the second version to be faster, I definitely did not expect such a gap. Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise.
> Note that this is possibly the result of a recent regression as I think I would have noticed this before otherwise.
How recent? (see edit) Flux.modules is new (as of v0.12) and uses a very different implementation (based on Functors.jl) than the for-loop it replaces in AlphaZero.jl. So, I am not surprised that they don't perform the same, but it's good to know that it is so slow. We definitely want to make the first code example performant, because that's how I'd expect Flux.modules to be used most of the time.
Unless by recent you mean within the last couple of weeks (i.e. comparing Flux.modules vs. Flux.modules with some updates), in which case Zygote would be the first place to look.
Maybe a good cross-check would be to replace Flux.modules with some other generator?
EDIT:
Actually they are very similar.
Before Flux v0.12, I was relying on my own implementation of modules.
Could it be that fcollect uses a vector as a cache and that's slower in AD? Not sure what else could be a factor.
It has been @nograd'd already.
Ah right, disregard that then.
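For reference, a Base-only sketch of the kind of pre-order traversal a modules-style collector performs: report the root, then recurse into children, skipping nodes already seen so that a layer shared in two places is reported once. This is not Functors.jl's fcollect; it uses an IdDict as the identity cache instead of a vector, and Seq, Leaf, children, and collect_modules are hypothetical names.

```julia
# Hypothetical container and leaf types standing in for Chain and Dense.
struct Leaf
    w::Matrix{Float32}
end
struct Seq
    layers::Vector{Any}
end

# Hypothetical children function: containers expose their layers, leaves expose none.
children(s::Seq) = s.layers
children(::Any) = ()

# Pre-order traversal; `seen` is an identity-based cache so shared nodes appear once.
function collect_modules!(out::Vector{Any}, seen::IdDict{Any,Nothing}, x)
    haskey(seen, x) && return out
    seen[x] = nothing
    push!(out, x)
    for c in children(x)
        collect_modules!(out, seen, c)
    end
    return out
end

collect_modules(root) = collect_modules!(Any[], IdDict{Any,Nothing}(), root)

shared = Leaf(rand(Float32, 2, 2))
model = Seq(Any[shared, Seq(Any[Leaf(rand(Float32, 2, 2)), shared])])
mods = collect_modules(model)
println(length(mods))
```

Since the collector runs outside the differentiated code when it is marked @nograd, its internal cache choice should not by itself affect the backward pass, which is consistent with the reply above.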
I played around with the examples above on CPU and GPU. I'm not sure why using modules inside the gradient context has such a crushing performance impact on CPU-side allocations, given that modules is @nograd.
using Flux, BenchmarkTools  # BenchmarkTools provides @btime
d = Chain((Dense(32, 32) for i in 1:5)...)
x = rand(Float32, 32, 8)
f1(m, x) = gradient(Flux.params(m)) do
ds = Flux.modules(m)
sum(l(x) for l in ds) |> sum
end
function f2(m, x)
ds = Flux.modules(m)
gradient(Flux.params(m)) do
sum(l(x) for l in ds) |> sum
end
end
julia> @timev f1(d, x)
24.901731 seconds (63.67 M allocations: 3.610 GiB, 4.22% gc time)
elapsed time (ns): 24901730841
gc time (ns): 1051987239
bytes allocated: 3875757116
pool allocs: 63656612
non-pool GC allocs:15360
malloc() calls: 88
realloc() calls: 15
GC pauses: 53
full collections: 2
julia> @timev f2(d, x)
1.086023 seconds (2.15 M allocations: 119.429 MiB, 2.75% gc time, 98.94% compilation time)
elapsed time (ns): 1086023325
gc time (ns): 29910632
bytes allocated: 125229937
pool allocs: 2152312
non-pool GC allocs:161
GC pauses: 3
julia> @btime f2($d, $x)
415.283 μs (2140 allocations: 210.67 KiB)
julia> @btime f1($d, $x)
414.055 μs (2146 allocations: 214.48 KiB)
FWIW, running the above example:
julia> @timev f1(d, x)
12.528920 seconds (61.88 M allocations: 3.262 GiB, 5.23% gc time, 99.97% compilation time)
elapsed time (ns): 12528919792
gc time (ns): 655876707
bytes allocated: 3502653566
pool allocs: 61852547
non-pool GC allocs:23328
malloc() calls: 103
realloc() calls: 20
GC pauses: 42
full collections: 1
Grads(...)
julia> @timev f2(d, x)
0.456313 seconds (2.03 M allocations: 103.419 MiB, 5.52% gc time, 99.67% compilation time)
elapsed time (ns): 456313042
gc time (ns): 25197167
bytes allocated: 108442549
pool allocs: 2025083
non-pool GC allocs:271
GC pauses: 2
Grads(...)
julia> @btime f2($d, $x)
min 171.500 μs, mean 209.663 μs (1631 allocations, 309.03 KiB. GC mean 11.82%)
Grads(...)
julia> @btime f1($d, $x)
min 173.167 μs, mean 209.863 μs (1636 allocations, 312.83 KiB. GC mean 11.27%)
Grads(...)
And after re-starting, running them in the opposite order:
julia> @timev f2(d, x)
12.708308 seconds (61.94 M allocations: 3.266 GiB, 5.74% gc time, 99.98% compilation time)
elapsed time (ns): 12708308041
gc time (ns): 729664123
bytes allocated: 3506520558
pool allocs: 61916466
non-pool GC allocs:23354
malloc() calls: 104
realloc() calls: 20
GC pauses: 39
full collections: 2
Grads(...)
julia> @timev f1(d, x)
0.447610 seconds (1.98 M allocations: 100.910 MiB, 2.59% gc time, 99.69% compilation time)
elapsed time (ns): 447609750
gc time (ns): 11576334
bytes allocated: 105811871
pool allocs: 1980090
non-pool GC allocs:250
GC pauses: 1
Grads(...)
(@v1.8) pkg> st Flux Zygote
Status `~/.julia/environments/v1.8/Project.toml`
[587475ba] Flux v0.12.8
[e88e6eb3] Zygote v0.6.32
So there's 12s of startup time on whichever is first, but no obvious effect of where Flux.modules(m) is called.
With the benefit of hindsight, the first run is basically measuring any AD compilation overhead, while the second can benefit from caching. It also looks like the original issue is resolved, so worth opening a new one if anything persists.
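Following that observation, a minimal Base-only pattern for separating first-call (compilation) cost from steady-state cost is to time the same call once cold and once warm. The function g is a hypothetical stand-in for the gradient call; in the thread, BenchmarkTools' @btime plays the role of the warm measurement, and the "% compilation time" column of @time/@timev shows how much of a cold run is compilation.

```julia
# Hypothetical workload standing in for a gradient call.
g(x) = sum(abs2, x) + sum(sin, x)

x = rand(Float32, 1000)

t_first = @elapsed g(x)   # includes compiling g for Vector{Float32}
t_warm  = @elapsed g(x)   # reuses the already compiled method

println("first call: ", t_first, " s; warm call: ", t_warm, " s")
```

Comparing only first calls, as in the @timev runs above, mostly measures which function happened to pay the compilation cost.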
I am trying to update AlphaZero.jl so that it works with Flux v0.12, but I am stuck on the following Zygote error:
The bug happens on both Flux@0.12.1 and Flux#master.
Replication instructions
To replicate, you can run the following using Julia 1.6:
Full stacktrace