FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Error with gradient of function based on Dictionary #1421

Open kishore-nori opened 1 year ago

kishore-nori commented 1 year ago

Hi,

I encountered the following errors when working with functions based on dictionaries. Below are the minimal failing examples (MFEs) and my naive attempts to fix them. (They seem to require some methods and adjoints for the Base.ValueIterator type.)

using Zygote 

function mfe1(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  sum(map(sum,values(collection)))
end

x = rand(3)

Zygote.gradient(mfe1, x)

The above results in the following error:

ERROR: MethodError: no method matching size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:580
  size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/cholesky.jl:514
  ...
Stacktrace:
  [1] axes
    @ ./abstractarray.jl:95 [inlined]
  [2] _tryaxes(x::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:188
  [3] map
    @ ./tuple.jl:221 [inlined]
  [4] ∇map(cx::Zygote.Context{false}, f::typeof(sum), args::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:203
  [5] _pullback(cx::Zygote.Context{false}, #unused#::typeof(collect), g::Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:244
  [6] _pullback
    @ ./abstractarray.jl:2961 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(map), ::typeof(sum), ::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] _pullback
    @ ./REPL[2]:4 [inlined]
  [9] _pullback(ctx::Zygote.Context{false}, f::typeof(mfe1), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:44
 [11] pullback
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:42 [inlined]
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:96

Since the error asks for size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}}), and a method length(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}}) already exists, I tried adding the following method:

Base.size(v::Union{Base.KeySet,Base.ValueIterator}) = (length(v.dict),)

I don't know whether this is the right way to go, but it seems to make the forward pass run without error. Now, however, Zygote.gradient asks for an adjoint; see the updated error:

ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
  [3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [4] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./abstractdict.jl:131 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[2]:4 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

Independently of the above, the following alternative MFE, called as Zygote.gradient(mfe2, x),

function mfe2(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  v = vcat(values(collection)...)
  sum(v)
end

throws the same "Need an adjoint" error as above:

ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
  [3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [4] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./abstractdict.jl:48 [inlined]
  [7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./abstractdict.jl:131 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[2]:4 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

I would like to know whether this can be fixed by writing the adjoint that the error requests, or whether there is a workaround for this issue. Thank you!
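For reference, the first error can be reproduced in plain Julia, without Zygote. The sketch below (Base only; the dictionary contents are illustrative) shows that `values(d)` returns a `Base.ValueIterator` that reports a length (`Base.HasLength()`) but no shape, so `size` has no method for it, while the forward `map(sum, values(d))` works fine and simply collects into a `Vector`. Judging from the first stack trace (`_tryaxes` calling `axes` inside `∇map`), Zygote's rule for `map` asks for the axes of its argument, which is where the missing `size` surfaces.

```julia
# Plain Julia, no Zygote needed: inspect the iterator returned by `values`.
d = Dict(:a => [1.0, 2.0, 3.0], :b => [1.0, 4.0, 9.0])
v = values(d)

# A ValueIterator advertises a length but no shape.
println(typeof(v))
println(Base.IteratorSize(v))   # HasLength, not HasShape
println(length(v))              # 2

# `size` (which `axes` relies on) has no method for it.
has_size = try
    size(v)
    true
catch err
    err isa MethodError || rethrow()
    false
end
println("size defined for ValueIterator: ", has_size)

# The forward computation itself is fine: `map` over a non-array iterator
# collects into a Vector, in the Dict's (unspecified) iteration order.
sums = map(sum, v)
println(sum(sums))              # 20.0
```

Because `Base.IteratorSize` is `HasLength()`, defining `size` for it (as in the attempt above) is type piracy on a Base type; it gets past the first error, but Zygote still has no rule for constructing the iterator in the backward pass.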

kishore-nori commented 1 year ago

A quick update: the following variation (MWE), which loops over the keys instead of the values, works around the problem. (So the issue is the lack of methods and rules for Base.ValueIterator, which the examples above go through.)

function mwe(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  s = zero(eltype(x))
  for k in keys(collection)
    s += sum(collection[k])
  end
  s
end

x = rand(3)

Zygote.gradient(mwe, x) # works! 

Edit: I realised this is not general enough. For example, if the values of the Dict have different eltypes, initialising s with zero(eltype(x)) is probably not a good idea.

kishore-nori commented 1 year ago

After some trial and error, I have a more generic form of the above workaround, for which Zygote.gradient works:

function mwe_generic(x::Vector)
  y = x.^2
  collection = Dict(:a => x, :b => y)
  s = zero(first(values(collection))[1])
  for k in keys(collection)
    @inbounds s += sum(collection[k])
  end
  s
end

x = rand(3)

Zygote.gradient(mwe_generic,x) # works! :)

Still, it would be good to have methods and an adjoint for Base.ValueIterator so that the original MFEs work!
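All of the functions above compute `sum(x) + sum(x.^2)`, so the expected gradient is `1 .+ 2 .* x`. As a sanity check that does not depend on AD at all, the sketch below uses a hypothetical central-difference helper `fd_gradient` (not part of Base or Zygote; the step size `h` is an assumption suited to these small, smooth inputs) and compares it with the analytic gradient. The same comparison can then be applied to whatever `Zygote.gradient` returns for a workaround.

```julia
# Hypothetical helper: central finite differences, no AD involved.
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h                                  # perturb one coordinate
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Same function as `mwe_generic` above; as plain Julia it runs without Zygote.
function mwe_generic(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        @inbounds s += sum(collection[k])
    end
    s
end

x = rand(3)
g_fd = fd_gradient(mwe_generic, x)

# Analytic gradient of sum(x) + sum(x.^2) is 1 .+ 2 .* x.
println(isapprox(g_fd, 1 .+ 2 .* x; atol = 1e-6))

# With Zygote loaded, the AD result can be checked the same way:
# isapprox(Zygote.gradient(mwe_generic, x)[1], g_fd; atol = 1e-6)
```

Since the function is quadratic in each coordinate, the central difference is exact up to floating-point rounding, so a tight tolerance is appropriate here; for non-polynomial functions a looser tolerance would be needed.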

kishore-nori commented 1 year ago

Unfortunately, the above workaround doesn't work for IdDict: iterating over its keys seems to hit a ccall (jl_eqtable_nextind) that Zygote cannot differentiate through. See the following:

function mfe_IdDict(x::Vector)
  y = x.^2
  collection = IdDict(:a => x, :b => y)
  s = zero(first(values(collection))[1])
  for k in keys(collection)
    @inbounds s += sum(collection[k])
  end
  s
end

julia> Zygote.gradient(mfe_IdDict,x)
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_nextind), UInt64, svec(Any, UInt64), 0, :(:ccall), %2, %5, %4)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./iddict.jl:143 [inlined]
  [3] (::Zygote.Pullback{Tuple{typeof(Base._oidd_nextind), Vector{Any}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.cconvert), Type{UInt64}, Int64}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#325"}}}, Zygote.Pullback{Tuple{typeof(reinterpret), Type{Int64}, UInt64}, Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction, Type{Int64}, UInt64}, Tuple{Core.IntrinsicFunction}}}}, Zygote.Pullback{Tuple{typeof(Base.unsafe_convert), Type{UInt64}, UInt64}, Tuple{}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [4] Pullback
    @ ./iddict.jl:146 [inlined]
  [5] (::Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
  [7] (::Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [8] Pullback
    @ ./abstractdict.jl:64 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(iterate), Base.KeySet{Symbol, IdDict{Symbol, Vector{Float64}}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[6]:7 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97

Hi @ToucheSir, are there plans to make Zygote work with IdDict? (Should I open a separate issue? I haven't found any IdDict-related issue in the issues section here.)

ToucheSir commented 1 year ago

There are no plans to make Zygote work better with any kind of Dict, but only because there is no developer capacity to do so, hence the labels I added above. Dicts are perhaps one of the trickiest types to add functionality or fix bugs for in Zygote, but if any brave soul wants to try, I'd be happy to guide them.