FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Gradient fails for Dict constructed with a generator or a vector of pairs #1293

Open Kolaru opened 2 years ago

Kolaru commented 2 years ago

Taking the gradient of a function that builds a Dict from a generator or from a vector of pairs fails, because the constructor being called contains a try/catch block.

For example, with a generator:

julia> using Zygote

julia> function f(x)
              a = Dict(c => x for c in 1:3)
              return a[1]
          end
f (generic function with 1 method)

julia> Zygote.gradient(f, 2.0)
ERROR: Compiling Tuple{Type{Dict}, Base.Generator{UnitRange{Int64}, var"#3#4"{Float64}}}: try/catch is not supported.

It works, however, if I splat the vector into the constructor:

julia> function f(x)
              a = Dict([c => x for c in 1:3]...)
              return a[1]
          end
f (generic function with 1 method)

julia> Zygote.gradient(f, 2.0)
(1.0,)

ToucheSir commented 2 years ago

Can you run this on the latest version of Zygote and post the full stacktrace? It seems like you're on an older version, and just the error message is not enough for us to work with. Thanks!

Kolaru commented 2 years ago

On Zygote master with Julia 1.7.1 I get:

julia> Zygote.gradient(f, 2.0)
ERROR: Compiling Tuple{Type{Dict}, Base.Generator{UnitRange{Int64}, var"#3#4"{Float64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:121
  [3] #Primal#23
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:205 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\reverse.jl:330
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\emit.jl:101
  [6] #s2770#1068
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface2.jl:28 [inlined]
  [7] var"#s2770#1068"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote .\none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:580
  [9] _pullback
    @ .\REPL[6]:2 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:44
 [12] pullback
    @ C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Float64)
    @ Zygote C:\Users\Kolaru\.julia\dev\Zygote\src\compiler\interface.jl:96
 [14] top-level scope
    @ REPL[7]:1

ToucheSir commented 2 years ago

Thanks. It looks like Dict(::Generator) hits a constructor which accepts any iterable. This constructor uses try/catch, so we'd have to add a rule for it. The codepath taken is somewhat tricky though, so if anyone wants to try this I'd recommend only dispatching for Generator to start.

In the meantime, if you know your key and value types up front, another workaround is to use the typed constructor instead:

a = Dict{Int,Float64}(c => x for c in 1:3)  # value type matches the Float64 input x

This bypasses the function with the try/catch and should be slightly faster to boot.
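
To see that the untyped and typed constructors really dispatch to different methods, something like the following can be run in a plain Julia session. It is only a sketch: it uses `which` from Base to look up the methods, does not involve Zygote at all, and the names `gen`, `m_untyped`, `m_typed` and `d_typed` are made up for illustration. The exact method signatures and source locations printed will depend on the Julia version.

```julia
# A generator of pairs, like the one in the original report
# (the value 2.0 stands in for the differentiated input x).
gen = (c => 2.0 for c in 1:3)

# Method selected for the untyped call Dict(gen):
# the generic constructor that accepts any iterable.
m_untyped = which(Dict, Tuple{typeof(gen)})

# Method selected for the typed call Dict{Int,Float64}(gen).
m_typed = which(Dict{Int,Float64}, Tuple{typeof(gen)})

println("untyped: ", m_untyped)
println("typed:   ", m_typed)

# The typed constructor builds the same contents with a fixed element type.
d_typed = Dict{Int,Float64}(gen)
println(typeof(d_typed))
```

If the two printed methods differ, the typed call is not going through the generic `Dict(kv)` constructor, which is the one Zygote complains about in the stacktrace above.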