Open · mentics opened this issue 2 years ago
This error is raised inside the `foldl` rrule when it is applied to an empty tuple.
That empty tuple comes from this `foreach`, which is called on a leaf node
(for example, a weight array):
# Reset a `Recur` layer: set `m.state` back to the cell's initial state `m.cell.state0`.
reset!(m::Recur) = (m.state = m.cell.state0)
# Generic fallback: recurse into the children returned by `functor(m)[1]`.
# For a leaf such as a `Dense` weight matrix, `functor(x)[1]` is `()`, so this ends up
# calling `foreach(reset!, ())`. When Zygote differentiates through that call, the
# ChainRules `mapfoldl_impl` rrule receives `Tuple{}` and throws
# `BoundsError: attempt to access Tuple{} at index [0]`.
reset!(m) = foreach(reset!, functor(m)[1])
julia> Functors.functor(Dense(1 => 1).weight)[1]
()
It still fails with https://github.com/JuliaDiff/ChainRules.jl/pull/569 applied, with this stacktrace:
julia> test()
ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:29
[2] last(a::Tuple{})
@ Base ./abstractarray.jl:500
[3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Base._InitialValue, x::Tuple{Nothing})
@ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:465
[4] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Nothing, x::Tuple{})
@ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:488
[5] chain_rrule
@ ~/.julia/packages/Zygote/qGFGD/src/compiler/chainrules.jl:218 [inlined]
[6] macro expansion
@ ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0 [inlined]
[7] _pullback(::Zygote.Context{true}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Nothing, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:9
[8] _pullback
@ ./reduce.jl:170 [inlined]
[9] _pullback(::Zygote.Context{true}, ::Base.var"##mapfoldl#286", ::Nothing, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[10] _pullback
@ ./reduce.jl:170 [inlined]
[11] _pullback(::Zygote.Context{true}, ::Base.var"#mapfoldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[12] _pullback
@ ./reduce.jl:193 [inlined]
[13] _pullback(::Zygote.Context{true}, ::Base.var"##foldl#287", ::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Nothing}}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[14] _pullback
@ ./reduce.jl:193 [inlined]
[15] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[16] _pullback
@ ./tuple.jl:602 [inlined]
[17] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[18] _pullback
@ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
[19] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[20] _pullback
@ ./abstractarray.jl:3036 [inlined]
[21] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[22] _pullback
@ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
[23] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
One possible fix is:
julia> ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})
julia> Zygote.refresh()
julia> test()
More generally, should Flux be differentiating inside `reset!`
at all?
> More generally, should Flux be differentiating inside `reset!` at all?
My understanding of https://github.com/FluxML/Flux.jl/pull/808#issuecomment-510864610 is that differentiating through `reset!` is intentional: it allows the initial state to be trainable. Perhaps there is another way for us to support that, though.
ERROR: BoundsError: attempt to access Tuple{} at index [0]
This error is thrown from Zygote code when `Flux.reset!` is called inside a Flux loss function. Versions: Julia 1.8.0, Zygote 0.6.45, Flux v0.13.5.
Test case:
Stacktrace:
Removing the call to Flux.reset! removes the error.