ChrisRackauckas closed 4 years ago
Interesting, is that a new occurrence or have we seen something like this before?
Presumably ignoring GC internal functions would get around this, which seems fine as a fix. It'd be good to understand why that comes up if it's a regression.
https://github.com/FluxML/Zygote.jl/issues/633 might be the same issue, but it's hard to figure out exactly what "it" is.
Seems to happen when a kw isn't declared beforehand.
julia> f(args...;sensealg=nothing,x=1,save_idxs=Colon(), kwargs...) = g(sensealg,args...;x=x, save_idxs = save_idxs, kwargs...)
f (generic function with 1 method)
julia> g(args...;x=1,save_idxs=Colon(),kwargs...) = x[save_idxs]
g (generic function with 1 method)
julia> Zygote.gradient(x->sum(f(;x=x,save_idxs=1:1, extrakw= 0)),ones(2))
([1.0, 0.0],)
julia> f(args...;a = nothing, b = 1, idx = Colon(), kwargs...) = g(a, args..., b = b, idx = idx, kwargs...)
f (generic function with 1 method)
julia> g(args...;b = 1, idx = Colon(), kwargs...) = b[idx]
g (generic function with 1 method)
julia> Zygote.gradient(x->sum(f(;b=x,idx=1:1)),ones(2))
ERROR: Need an adjoint for constructor Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}. Gradient is of type Tuple{}
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.Jnew{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing,false})(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/lib.jl:306
[3] (::Zygote.var"#381#back#195"{Zygote.Jnew{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing,false}})(::Tuple{}) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] Pairs at ./iterators.jl:169 [inlined]
[5] (::typeof(∂(Base.Iterators.Pairs)))(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
[6] pairs at ./iterators.jl:226 [inlined]
[7] (::typeof(∂(pairs)))(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
[8] #f at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
[9] (::typeof(∂(#f)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
[10] #7 at ./REPL[5]:1 [inlined]
[11] (::typeof(∂(#7)))(::Float64) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
[12] (::Zygote.var"#38#39"{typeof(∂(#7))})(::Float64) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface.jl:46
[13] gradient(::Function, ::Array{Float64,1}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface.jl:55
[14] top-level scope at REPL[5]:1
This shows the error a bit more clearly. Or a simpler MWE:
julia> f(x; id = 1) = pairs(x)[id]
f (generic function with 1 method)
Defining an adjoint rule for Base.Iterators.Pairs should take care of this.
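For context on where that type comes from, here is a minimal sketch in plain Julia (no Zygote involved). It assumes only Base, and the helper name collect_kwargs is hypothetical, not something from the thread. A function that slurps keyword arguments receives them wrapped in a Pairs object over a NamedTuple; when no extra keywords are passed, that wrapper is empty, which appears to be the constructor named in the error message. The exact printed type varies across Julia versions, so no output is shown.

```julia
# Hypothetical helper: returns whatever the `kwargs...` slurp is bound to.
collect_kwargs(; kwargs...) = kwargs

empty_kw = collect_kwargs()            # no extra keywords passed
some_kw  = collect_kwargs(idx = 1:1)   # one extra keyword passed

# The slurped keywords arrive as a Pairs wrapper around a NamedTuple.
@show typeof(empty_kw)
@show isempty(empty_kw)

# `values` is the public accessor for the wrapped NamedTuple,
# so there is no need to reach into the `.data` field directly.
@show values(some_kw)
```

Using values here rather than the data field is a design choice made for the sketch only; whether an adjoint should do the same is a separate question.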
Zygote doesn't attempt to differentiate kwargs, but here is a quick and dirty adjoint that I threw together just to check that this method was in fact being hit. It seems to work, but we'll want a more general adjoint here.
julia> Zygote.@adjoint function Base.pairs(x::T) where T
         y = Base.pairs(x)
         back(dx::NamedTuple) = (dx.data,)
         function back(dx::Dict)
           T <: AbstractDict && return (dx,)
           z = zero(x)
           for (k,v) in dx
             z[k] = v
           end
           (z,)
         end
         y, back
       end
Seems to fix the issue for me.
Might be better if we restrict the type to just a NamedTuple, since that would get rid of the original error and we wouldn't have to deal with the wonky dict part.
@adjoint function Base.pairs(x::NamedTuple)
  Base.pairs(x), Δ -> (Δ.data,)
end
I don't think we should add adjoints that access internal fields directly for such things.
The adjoint on pairs seems sufficient for my applications. Is there any chance this could get in a tagged release by the end of the day?
Sorry, it was like 2am when I did this, and I missed this when checking messages today.
I'd like some review on this before merging since I'm not sure if it's a good idea to access member fields in an adjoint.
Side note, it might just be something handled better in the compiler.
Should be fixed on master
Interestingly, the save_idxs=1:1 thing seems to be required in order for this to trigger.