FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

BUG: getindex(::Dict) #760

Closed: jamblejoe closed this issue 1 year ago

jamblejoe commented 4 years ago

The following code

using Zygote

function f(x)
    d = Dict()

    for i in 1:4
        push!(d, i => i^x)
    end

    sum(values(d))
end

@show f(3)
gradient(f, 3)

errors with

MethodError: no method matching getindex(::Dict{Any,Any})
Closest candidates are:
  getindex(::Dict{K,V}, !Matched::Any) where {K, V} at dict.jl:465
  getindex(::AbstractDict, !Matched::Any) at abstractdict.jl:489
  getindex(::AbstractDict, !Matched::Any, !Matched::Any, !Matched::Any...) at abstractdict.jl:499

Stacktrace:
 [1] (::Zygote.var"#back#187"{:vals,Zygote.Context,Dict{Any,Any},Array{Any,1}})(::Array{Union{Nothing, Int64},1}) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\lib\lib.jl:207
 [2] (::Zygote.var"#1743#back#188"{Zygote.var"#back#187"{:vals,Zygote.Context,Dict{Any,Any},Array{Any,1}}})(::Array{Union{Nothing, Int64},1}) at C:\Users\Goran\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [3] iterate at .\dict.jl:682 [inlined]
 [4] (::typeof(∂(iterate)))(::Tuple{Int64,Nothing}) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [5] _foldl_impl at .\reduce.jl:60 [inlined]
 [6] (::typeof(∂(_foldl_impl)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [7] foldl_impl at .\reduce.jl:48 [inlined]
 [8] (::typeof(∂(foldl_impl)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [9] mapfoldl_impl at .\reduce.jl:44 [inlined]
 [10] (::typeof(∂(mapfoldl_impl)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [11] #mapfoldl#204 at .\reduce.jl:160 [inlined]
 [12] (::typeof(∂(#mapfoldl#204)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [13] mapfoldl at .\reduce.jl:160 [inlined]
 [14] (::typeof(∂(mapfoldl)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [15] #mapreduce#208 at .\reduce.jl:287 [inlined]
 [16] (::typeof(∂(#mapreduce#208)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [17] mapreduce at .\reduce.jl:287 [inlined]
 [18] (::typeof(∂(mapreduce)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [19] sum at .\reduce.jl:494 [inlined]
 [20] sum at .\reduce.jl:511 [inlined]
 [21] (::typeof(∂(sum)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [22] f at .\In[88]:10 [inlined]
 [23] (::typeof(∂(f)))(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface2.jl:0
 [24] (::Zygote.var"#41#42"{typeof(∂(f))})(::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:45
 [25] gradient(::Function, ::Int64) at C:\Users\Goran\.julia\packages\Zygote\seGHk\src\compiler\interface.jl:54
 [26] top-level scope at In[88]:14
 [27] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091

I assume that this is not intended.

gxyd commented 4 years ago

I assume that this is not intended.

I agree. I expected the value to be:

julia> (1^3)*log(1) + (2^3)*log(2) + (3^3)*log(3) + (4^3)*log(4)
123.93054835019151
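As a sanity check on that number: since d/dx i^x = i^x log(i), the expected gradient is the sum of those terms over i = 1..4. The stdlib-only sketch below (the names f_ref, df_ref and central_diff are illustrative, not from Zygote) computes it analytically and cross-checks it with a central finite difference, so it does not depend on Zygote or on Dict at all.

```julia
# Reference version of f without a Dict: f_ref(x) = 1^x + 2^x + 3^x + 4^x
f_ref(x) = sum(i^x for i in 1:4)

# Analytic derivative, using d/dx i^x = i^x * log(i)
df_ref(x) = sum(i^x * log(i) for i in 1:4)

# Central finite difference, as an independent numerical cross-check
central_diff(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

println(df_ref(3.0))              # should be close to 123.93054835019151
println(central_diff(f_ref, 3.0)) # should agree with the line above to several digits
```

Whatever Zygote returns for gradient(f, 3) should match these values up to floating-point rounding.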

If no one is working on it, may I try fixing this issue?

DhairyaLGandhi commented 4 years ago

Could you try this branch: https://github.com/DhairyaLGandhi/Zygote.jl/tree/dg/iddict

gxyd commented 4 years ago

I tried it on that branch, but it seems to give me the same error:

ERROR: MethodError: no method matching getindex(::Dict{Any,Any})
Closest candidates are:
  getindex(::Dict{K,V}, ::Any) where {K, V} at dict.jl:476
  getindex(::AbstractDict, ::Any) at abstractdict.jl:469
  getindex(::AbstractDict, ::Any, ::Any, ::Any...) at abstractdict.jl:478
Stacktrace:
 [1] (::Zygote.var"#back#155"{:vals,Zygote.Context,Dict{Any,Any},Array{Any,1}})(::Array{Union{Nothing, Int64},1}) at /Users/gaurav/.julia/dev/Zygote/src/lib/lib.jl:204
 [2] (::Zygote.var"#1686#back#156"{Zygote.var"#back#155"{:vals,Zygote.Context,Dict{Any,Any},Array{Any,1}}})(::Array{Union{Nothing, Int64},1}) at /Users/gaurav/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] iterate at ./dict.jl:692 [inlined]
 [4] (::typeof(∂(iterate)))(::Tuple{Int64,Nothing}) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [5] _foldl_impl at ./reduce.jl:57 [inlined]
 [6] (::typeof(∂(_foldl_impl)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [7] foldl_impl at ./reduce.jl:45 [inlined]
 [8] (::typeof(∂(foldl_impl)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [9] mapfoldl_impl at ./reduce.jl:41 [inlined]
 [10] (::typeof(∂(mapfoldl_impl)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [11] #mapfoldl#189 at ./reduce.jl:157 [inlined]
 [12] (::typeof(∂(#mapfoldl#189)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [13] mapfoldl at ./reduce.jl:157 [inlined]
 [14] (::typeof(∂(mapfoldl)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [15] #mapreduce#193 at ./reduce.jl:283 [inlined]
 [16] (::typeof(∂(#mapreduce#193)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [17] mapreduce at ./reduce.jl:283 [inlined]
 [18] (::typeof(∂(mapreduce)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [19] sum at ./reduce.jl:486 [inlined]
 [20] sum at ./reduce.jl:503 [inlined]
 [21] (::typeof(∂(sum)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [22] f at ./REPL[3]:6 [inlined]
 [23] (::typeof(∂(f)))(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#41#42"{typeof(∂(f))})(::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface.jl:45
 [25] gradient(::Function, ::Int64) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface.jl:54
 [26] top-level scope at REPL[4]:1

gxyd commented 4 years ago

One thing I don't understand in general: while stepping through the debugger, why are certain variables unavailable when it shows lowered code like this?

Here is an example from when I tried to debug the above:

1|debug> n
In _pullback(ctx, f, args) at /Users/gaurav/.julia/dev/Zygote/src/compiler/interface2.jl:13
 1  1 ─      $(Expr(:meta, :inline))
 2  │   %2 = (getfield)(args, 1)
 3  │   %3 = (ZygoteRules._pullback)(ctx, Base.ValueIterator, %2)
 4  │   %4 = (getindex)(%3, 1)
 5  │   %5 = (getindex)(%3, 2)
 6  │   %6 = (tuple)(%5)
>7  │   %7 = (typeof(∂(values)))(%6)
 8  │   %8 = (tuple)(%4, %7)
 9  └──      return %8

About to run: (typeof(∂(values)))((∂(Base.ValueIterator),))
1|julia> (tuple)(∂(Base.ValueIterator))
ERROR: UndefVarError: ∂ not defined
Stacktrace:
 [1] top-level scope at REPL[17]:1
 [2] eval at ./boot.jl:331 [inlined]
 [3] eval_code(::JuliaInterpreter.Frame, ::Expr) at /Users/gaurav/.julia/packages/JuliaInterpreter/CPmYX/src/utils.jl:595
 [4] eval_code(::JuliaInterpreter.Frame, ::String) at /Users/gaurav/.julia/packages/JuliaInterpreter/CPmYX/src/utils.jl:572
 [5] _eval_code(::JuliaInterpreter.Frame, ::String) at /Users/gaurav/.julia/packages/Debugger/Xr8bu/src/repl.jl:202
 [6] (::Debugger.var"#27#29"{Debugger.DebuggerState})(::REPL.LineEdit.MIState, ::Base.GenericIOBuffer{Array{UInt8,1}}, ::Bool) at /Users/gaurav/.julia/packages/Debugger/Xr8bu/src/repl.jl:185
 [7] #invokelatest#1 at ./essentials.jl:712 [inlined]
 [8] invokelatest at ./essentials.jl:711 [inlined]
 [9] run_interface(::REPL.Terminals.TextTerminal, ::REPL.LineEdit.ModalInterface, ::REPL.LineEdit.MIState) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.4/REPL/src/LineEdit.jl:2354
 [10] run_interface(::REPL.Terminals.TextTerminal, ::REPL.LineEdit.ModalInterface) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.4/REPL/src/LineEdit.jl:2348
 [11] RunDebugger(::JuliaInterpreter.Frame, ::Nothing, ::Nothing; initial_continue::Bool) at /Users/gaurav/.julia/packages/Debugger/Xr8bu/src/repl.jl:158
 [12] RunDebugger at /Users/gaurav/.julia/packages/Debugger/Xr8bu/src/repl.jl:4 [inlined] (repeats 2 times)
 [13] top-level scope at /Users/gaurav/.julia/packages/Debugger/Xr8bu/src/Debugger.jl:126

Although it explicitly states About to run: (typeof(∂(values)))((∂(Base.ValueIterator),)), which implies that ∂ has already been defined, entering the REPL and trying to print it explicitly raises an UndefVarError.
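A possible explanation, which I have not verified against the Zygote source here: ∂(...) may only be how Zygote *prints* its pullback objects (through a custom Base.show method), so the debugger's "About to run" line is display text rather than code that can be re-evaluated. The self-contained sketch below uses a hypothetical ToyPullback type (not Zygote's) to reproduce the same UndefVarError with Base Julia only.

```julia
# Hypothetical stand-in for a pullback object; not Zygote's actual type.
struct ToyPullback{F}
    f::F
end

# Custom display: the object prints as "∂(f)", but this defines no binding named ∂.
Base.show(io::IO, p::ToyPullback) = print(io, "∂(", p.f, ")")

p = ToyPullback(sin)
shown = repr(p)    # the text a debugger would display for p
println(shown)

# Re-evaluating the displayed text fails, since ∂ was never defined as a function.
err = try
    eval(Meta.parse(shown))
    nothing
catch e
    e
end
println(err isa UndefVarError)
```

If this is what happens in Zygote, the object itself exists (as %7's callee above), but there is simply no global named ∂ to look up in the REPL.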

rlrs commented 4 years ago

This is probably related to #725, at least in the sense that dictionaries are very hard to get working. I've attempted some debugging (see #725), but I don't think I understand enough to solve the dictionary issues myself. I'd be happy to contribute a fix if someone could advise me on how.

rlrs commented 4 years ago

I'm still trying to solve these issues, so here is a status update on what I know, to keep this issue alive. In this case, sum calls iterate, which accesses the dictionary's fields .vals and .keys; differentiating those field accesses goes through the generic adjoint for structs defined in https://github.com/FluxML/Zygote.jl/blob/9dc33c5b402e16f23510877b46c53ca04259ca0f/src/lib/lib.jl#L196-L209.
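To make the field access concrete: a Dict keeps its entries in internal arrays, and iterating values(d) reads the occupied slots of d.vals. The sketch below only inspects those internals; the field names are Base implementation details (not public API) and may differ between Julia versions.

```julia
# Same dictionary as in the OP, built with a fixed exponent of 3
d = Dict{Any,Any}()
for i in 1:4
    push!(d, i => i^3)
end

# Internal storage fields that iterate reads; Zygote sees them as ordinary struct fields
println(fieldnames(typeof(d)))

# With no deletions, values(d) visits exactly the assigned slots of the internal vals array
occupied = [d.vals[i] for i in eachindex(d.vals) if isassigned(d.vals, i)]
println(sort(occupied) == sort(collect(values(d))))
```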

The catch is that grad_mut and the associated __context__ are also used to define some dictionary-specific adjoints: https://github.com/FluxML/Zygote.jl/blob/9dc33c5b402e16f23510877b46c53ca04259ca0f/src/lib/base.jl#L22-L45

Note that literal_getproperty assumes that grad_mut returns a reference to a NamedTuple. Would it be preferable to make literal_getproperty work for dictionaries, or should we implement a dictionary-specific version? One problem I see is that .vals is a vector with #undef entries, so, at the very least, those entries need to be handled in functions such as accum.
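The #undef problem can be illustrated without Zygote. The helpers below are hypothetical (they are not Zygote's grad_mut or accum): slot_grads builds a per-slot gradient vector for d.vals that holds `nothing` wherever a slot is #undef, much like the Array{Union{Nothing, Int64},1} seen in the stack trace, and accum_slots adds two such vectors while treating `nothing` as "no contribution".

```julia
# Hypothetical helpers for illustration; not Zygote's actual accum or grad_mut.

# Per-slot gradient for an internal storage vector: `nothing` where the slot is #undef
slot_grads(v::AbstractVector, g) =
    [isassigned(v, i) ? g(v[i]) : nothing for i in eachindex(v)]

# Elementwise accumulation in which `nothing` means "no gradient contribution"
accum_slots(a, b) = map(a, b) do x, y
    x === nothing ? y : y === nothing ? x : x + y
end

d = Dict{Any,Any}()
for i in 1:4
    push!(d, i => i^3)
end

# Reading d.vals[i] directly would throw UndefRefError on the empty slots,
# so the isassigned check is what keeps this from failing.
gs = slot_grads(d.vals, x -> 1)
println(count(!isnothing, gs))   # one non-nothing entry per stored value
println(accum_slots(gs, gs))
```

The design choice here is to keep the gradient vector the same length as the storage vector, so slot indices line up, rather than compacting it to only the stored values.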

CarloLucibello commented 1 year ago

For some reason, the example in the OP works now. I will close this issue and add tests.

jamblejoe commented 1 year ago

I can confirm that it works with v0.6.49.