FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
Other
1.48k stars 213 forks source link

BoundsError calling Flux.reset! #1297

Open mentics opened 2 years ago

mentics commented 2 years ago

Calling `Flux.reset!` inside a Flux loss function throws `ERROR: BoundsError: attempt to access Tuple{} at index [0]` from Zygote code.

Versions: Julia 1.8.0, Zygote 0.6.45, Flux 0.13.5

Test case:

using Flux
# Minimal reproducer: calling `Flux.reset!` inside the loss closure makes
# `Flux.train!` throw the BoundsError shown in the stacktrace below.
function test()
    model = Dense(1 => 1)
    params = Flux.params(model)
    function loss(x, y)
        # This call triggers the error; removing it makes the error go away.
        Flux.reset!(model)
        Flux.Losses.mse(model(x), y)
    end
    # A single (input, target) sample is passed as the training data.
    Flux.train!(loss, params, [([1.0],[1.0])], Descent())
end

Stacktrace:

ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base .\tuple.jl:29
  [2] last(a::Tuple{})
    @ Base .\abstractarray.jl:479
  [3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(foldl), op::Base.var"#57#58"{typeof(Flux.reset!)}, x::Tuple{}; init::Nothing)
    @ ChainRules C:\Users\joel\.julia\packages\ChainRules\fgVxV\src\rulesets\Base\mapreduce.jl:448
  [4] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::Function, ::NamedTuple{(:init,), Tuple{Nothing}}, ::Function, ::Function, ::Vararg{Any})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\chainrules.jl:230
  [5] macro expansion
    @ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0 [inlined]
  [6] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:9
  [7] _pullback
    @ .\tuple.jl:555 [inlined]
  [8] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
  [9] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
 [10] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [11] _pullback
    @ .\abstractarray.jl:2774 [inlined]
 [12] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [13] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
 [14] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [15] _pullback
    @ C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:23 [inlined]
 [16] _pullback(::Zygote.Context{true}, ::TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any})
    @ Core .\boot.jl:816
 [18] adjoint
    @ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\lib\lib.jl:203 [inlined]
 [19] _pullback
    @ C:\Users\joel\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
 [20] _pullback
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:120 [inlined]
 [21] _pullback(::Zygote.Context{true}, ::Flux.Optimise.var"#37#40"{TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Vector{Float64}, Vector{Float64}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
 [22] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:373
 [23] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:96
 [24] macro expansion
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:119 [inlined]
 [25] macro expansion
    @ C:\Users\joel\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [26] train!(loss::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, data::Vector{Tuple{Vector{Float64}, Vector{Float64}}}, opt::Flux.Optimise.Descent; cb::Flux.Optimise.var"#38#41")
    @ Flux.Optimise C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:117
 [27] train!
    @ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:113 [inlined]
 [28] test()
    @ TryFlux C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:26
 [29] top-level scope
    @ REPL[12]:1

Removing the call to Flux.reset! removes the error.

mcabbott commented 2 years ago

This is from inside the `foldl` rule, called on an empty tuple, which comes from this `foreach` on a leaf node:

# For a `Recur` layer: set the current state back to the cell's `state0`.
reset!(m::Recur) = (m.state = m.cell.state0)
# Fallback for everything else: recurse into the children given by `functor(m)[1]`.
# For a leaf node (e.g. a weight array) the children are `()`, so this is `foreach` over
# an empty tuple, which is the call the failing `foldl` rule is hit from.
reset!(m) = foreach(reset!, functor(m)[1])
julia> Functors.functor(Dense(1 => 1).weight)[1]
()

Still fails with https://github.com/JuliaDiff/ChainRules.jl/pull/569 with this stacktrace:

julia> test()
ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
  [1] getindex(t::Tuple, i::Int64)
    @ Base ./tuple.jl:29
  [2] last(a::Tuple{})
    @ Base ./abstractarray.jl:500
  [3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Base._InitialValue, x::Tuple{Nothing})
    @ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:465
  [4] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Nothing, x::Tuple{})
    @ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:488
  [5] chain_rrule
    @ ~/.julia/packages/Zygote/qGFGD/src/compiler/chainrules.jl:218 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0 [inlined]
  [7] _pullback(::Zygote.Context{true}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Nothing, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:9
  [8] _pullback
    @ ./reduce.jl:170 [inlined]
  [9] _pullback(::Zygote.Context{true}, ::Base.var"##mapfoldl#286", ::Nothing, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [10] _pullback
    @ ./reduce.jl:170 [inlined]
 [11] _pullback(::Zygote.Context{true}, ::Base.var"#mapfoldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [12] _pullback
    @ ./reduce.jl:193 [inlined]
 [13] _pullback(::Zygote.Context{true}, ::Base.var"##foldl#287", ::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Nothing}}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [14] _pullback
    @ ./reduce.jl:193 [inlined]
 [15] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [16] _pullback
    @ ./tuple.jl:602 [inlined]
 [17] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [18] _pullback
    @ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
 [19] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [20] _pullback
    @ ./abstractarray.jl:3036 [inlined]
 [21] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
    @ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
 [23] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})

One possible fix is:

julia> ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})

julia> Zygote.refresh()

julia> test()

More generally, should Flux be differentiating inside reset! at all?

ToucheSir commented 2 years ago

> More generally, should Flux be differentiating inside reset! at all?

My understanding of https://github.com/FluxML/Flux.jl/pull/808#issuecomment-510864610 is that it's intentional to allow initial state to be trainable, but perhaps there's another way for us to make that work.