kishore-nori opened this issue 1 year ago.
Just to update: the following variation of the MWE, which loops over the keys instead of the values, is a workaround. (So the problem is the lack of methods and rules for `Base.ValueIterator`, which the methods above invoke.)
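```julia
using Zygote

function mwe(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    s = zero(eltype(x))
    # Loop over the keys rather than the values, so Base.ValueIterator is never used
    for k in keys(collection)
        s += sum(collection[k])
    end
    s
end

x = rand(3)
Zygote.gradient(mwe, x) # works!
```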
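To sanity-check what `Zygote.gradient(mwe, x)` returns: `mwe(x)` equals `sum(x) + sum(x.^2)`, so the expected gradient is `1 .+ 2 .* x`. The sketch below confirms this with a plain-Julia central finite difference, no Zygote involved; the helper `fd_gradient` is hypothetical and not part of Zygote.

```julia
# Same function as above; mwe(x) == sum(x) + sum(x.^2)
function mwe(x::Vector)
    y = x .^ 2
    collection = Dict(:a => x, :b => y)
    s = zero(eltype(x))
    for k in keys(collection)
        s += sum(collection[k])
    end
    s
end

# Hypothetical helper: central finite-difference gradient of a scalar function
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    g
end

x = rand(3)
g_fd = fd_gradient(mwe, x)
g_exact = 1 .+ 2 .* x   # analytic gradient of sum(x) + sum(x.^2)
```

Comparing `Zygote.gradient(mwe, x)[1]` against `g_exact` with `isapprox` gives a quick check for the other workarounds in this thread too.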
Edit: I realised this is not general enough; for example, if the values of the `Dict` have different eltypes, this is probably not a good idea.
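A plain-Julia illustration (no Zygote) of what happens when the values have different eltypes: an accumulator initialised from one value's eltype silently changes type inside the loop. The helper `accumulate_from` is hypothetical and only mirrors the loop in `mwe`; it is wrapped in a function to avoid top-level scoping issues.

```julia
# Hypothetical helper mirroring the loop in mwe, with an explicit initial value
function accumulate_from(s0, collection)
    s = s0
    for k in keys(collection)
        s += sum(collection[k])   # Float32 + Float64 promotes to Float64
    end
    s
end

# Values with different eltypes: Vector{Float32} and Vector{Float64}
collection = Dict(:a => Float32[1, 2], :b => [0.5, 1.5])
s = accumulate_from(zero(Float32), collection)  # starts as Float32, ends as Float64
```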
After some trial and error, I have a generic form of the above workaround, for which `Zygote.gradient` works:
```julia
function mwe_generic(x::Vector)
    y = x.^2
    collection = Dict(:a => x, :b => y)
    # Initialise the accumulator from the first stored value instead of eltype(x)
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        @inbounds s += sum(collection[k])
    end
    s
end

x = rand(3)
Zygote.gradient(mwe_generic, x) # works! :)
```
But it would be good to have methods and adjoints for `Base.ValueIterator` so that the original MFE works!
The above workaround unfortunately doesn't work for `IdDict`: it seems to hit a `ccall`, which Zygote can't differentiate through. See the following:
```julia
function mfe_IdDict(x::Vector)
    y = x.^2
    collection = IdDict(:a => x, :b => y)
    s = zero(first(values(collection))[1])
    for k in keys(collection)
        @inbounds s += sum(collection[k])
    end
    s
end
```
```
julia> Zygote.gradient(mfe_IdDict,x)
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_nextind), UInt64, svec(Any, UInt64), 0, :(:ccall), %2, %5, %4)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./iddict.jl:143 [inlined]
  [3] (::Zygote.Pullback{Tuple{typeof(Base._oidd_nextind), Vector{Any}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.cconvert), Type{UInt64}, Int64}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#325"}}}, Zygote.Pullback{Tuple{typeof(reinterpret), Type{Int64}, UInt64}, Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction, Type{Int64}, UInt64}, Tuple{Core.IntrinsicFunction}}}}, Zygote.Pullback{Tuple{typeof(Base.unsafe_convert), Type{UInt64}, UInt64}, Tuple{}}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [4] Pullback
    @ ./iddict.jl:146 [inlined]
  [5] (::Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
  [6] #287
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
  [7] (::Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
  [8] Pullback
    @ ./abstractdict.jl:64 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(iterate), Base.KeySet{Symbol, IdDict{Symbol, Vector{Float64}}}, Int64}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[6]:7 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
```
Hi @ToucheSir, are there plans to make Zygote work with `IdDict`? (Should I open a separate issue? I haven't found any `IdDict`-related issue here.)
There are no plans to improve Zygote's support for any kind of `Dict`, but only because there is no developer capacity to do so, hence the labels I added above. Dicts are perhaps one of the trickiest types to add new functionality to or fix bugs for in Zygote, but if any brave soul wants to try, I'd be happy to guide them.
Hi,

I encountered the following errors when working with functions based on dictionaries. Below are the minimal failing examples (MFEs) and my naive attempts. (They seem to require some methods and adjoints for the `Base.ValueIterator` type.)

The above results in the following error:

Since the error asks for `size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})`, and realising that the method `length(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})` exists, I tried adding the following method. I don't know if it is the right way forward, but it makes the forward pass error-free, I guess. Now, however, `Zygote.gradient` requests an adjoint; see the following updated error:

Independently of the above, the following alternative MFE throws the same `Need an adjoint` error as above:

I would be happy to know whether this is fixable by writing the adjoint that the error requests, or whether there is a workaround for this issue. Thank you!
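For reference, in plain Julia (no differentiation), `values` on a `Dict` returns a lazy `Base.ValueIterator`. It defines `length` and iteration, but not array-style queries such as the `size` method the error above asked for; `collect` materializes it into a `Vector`, which does have `size`. This sketch only shows the forward behaviour; whether Zygote can differentiate through `collect(values(d))` is not established in this thread.

```julia
d = Dict(:a => [1.0, 2.0], :b => [3.0])
v = values(d)           # lazy Base.ValueIterator over the Dict's values

n = length(v)           # length is defined for ValueIterator
total = sum(sum, v)     # iteration works, so reducing over the values is fine
vals = collect(v)       # Vector{Vector{Float64}}, which supports size
```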