FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

keyword argument splatting can cause gc_preserve_end issues? #664

Closed ChrisRackauckas closed 4 years ago

ChrisRackauckas commented 4 years ago
using Zygote
f(args...;sensealg=nothing,kwargs...) = g(sensealg,args...;kwargs...)
g(args...;x=1,save_idxs=Colon(),kwargs...) = x[save_idxs]
Zygote.gradient(x->sum(f(;x=x,save_idxs=1:1)),ones(2))
Can't differentiate gc_preserve_end expression
error(::String) at error.jl:33
getindex at essentials.jl:591 [inlined]
(::typeof(∂(getindex)))(::Nothing) at interface2.jl:0
typejoin at promotion.jl:103 [inlined]
(::typeof(∂(typejoin)))(::Nothing) at interface2.jl:0
_promote_typejoin at promotion.jl:131 [inlined]
promote_typejoin at promotion.jl:130 [inlined]
(::typeof(∂(promote_typejoin)))(::Nothing) at interface2.jl:0
_compute_eltype at tuple.jl:117 [inlined]
(::typeof(∂(_compute_eltype)))(::Nothing) at interface2.jl:0
eltype at tuple.jl:110 [inlined]
eltype at namedtuple.jl:145 [inlined]
Pairs at iterators.jl:169 [inlined]
pairs at iterators.jl:226 [inlined]
(::typeof(∂(pairs)))(::NamedTuple{(:data, :itr),Tuple{NamedTuple{(:x, :save_idxs),Tuple{Array{Float64,1},Nothing}},Nothing}}) at interface2.jl:0
f at test.jl:2 [inlined]
(::typeof(∂(f##kw)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at interface2.jl:0
#78 at test.jl:4 [inlined]
(::typeof(∂(#78)))(::Float64) at interface2.jl:0
(::Zygote.var"#36#37"{typeof(∂(#78))})(::Float64) at interface.jl:46
gradient(::Function, ::Array{Float64,1}, ::Vararg{Array{Float64,1},N} where N) at interface.jl:55
top-level scope at test.jl:4

Interestingly, the save_idxs=1:1 thing seems to be required in order for this to trigger.

DhairyaLGandhi commented 4 years ago

Interesting. Is this a new occurrence, or have we seen something like this before?

Presumably ignoring GC-internal functions would get around this, which seems fine as a fix. It'd be good to understand why it comes up, in case it's a regression.

ChrisRackauckas commented 4 years ago

https://github.com/FluxML/Zygote.jl/issues/633 might be the same issue, but it's hard to figure out exactly what "it" is.

DhairyaLGandhi commented 4 years ago

This seems to happen when a keyword argument isn't declared beforehand.

julia> f(args...;sensealg=nothing,x=1,save_idxs=Colon(), kwargs...) = g(sensealg,args...;x=x, save_idxs = save_idxs, kwargs...)
f (generic function with 1 method)

julia> g(args...;x=1,save_idxs=Colon(),kwargs...) = x[save_idxs]
g (generic function with 1 method)

julia> Zygote.gradient(x->sum(f(;x=x,save_idxs=1:1, extrakw= 0)),ones(2))
([1.0, 0.0],)

cossio commented 4 years ago

Also https://github.com/FluxML/Zygote.jl/issues/584

DhairyaLGandhi commented 4 years ago
julia> f(args...;a = nothing, b = 1, idx = Colon(), kwargs...) = g(a, args..., b = b, idx = idx, kwargs...)
f (generic function with 1 method)

julia> g(args...;b = 1, idx = Colon(), kwargs...) = b[idx]
g (generic function with 1 method)

julia> Zygote.gradient(x->sum(f(;b=x,idx=1:1)),ones(2))
ERROR: Need an adjoint for constructor Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}. Gradient is of type Tuple{}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.Jnew{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing,false})(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/lib.jl:306
 [3] (::Zygote.var"#381#back#195"{Zygote.Jnew{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing,false}})(::Tuple{}) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] Pairs at ./iterators.jl:169 [inlined]
 [5] (::typeof(∂(Base.Iterators.Pairs)))(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
 [6] pairs at ./iterators.jl:226 [inlined]
 [7] (::typeof(∂(pairs)))(::Tuple{}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
 [8] #f at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
 [9] (::typeof(∂(#f)))(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
 [10] #7 at ./REPL[5]:1 [inlined]
 [11] (::typeof(∂(#7)))(::Float64) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#38#39"{typeof(∂(#7))})(::Float64) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface.jl:46
 [13] gradient(::Function, ::Array{Float64,1}) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/compiler/interface.jl:55
 [14] top-level scope at REPL[5]:1

This shows the error a bit more clearly. Here is a simpler MWE:

julia> f(x; id = 1) = pairs(x)[id]
f (generic function with 1 method)

Defining an adjoint rule for Base.Iterators.Pairs should take care of this.
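
For context, here is a minimal sketch in plain Julia (no Zygote needed) of why pairs appears in these stack traces at all: the collection bound by a kwargs... slurp wraps a NamedTuple of the keywords that were passed. The helper name capture is hypothetical and exists only for illustration. The printed type depends on the Julia version, so no output is shown.

```julia
# `capture` is a hypothetical helper: it returns whatever a `kwargs...` slurp binds.
capture(; kwargs...) = kwargs

kw = capture(x = ones(2), save_idxs = 1:1)

# `values` recovers the NamedTuple that the slurped keywords wrap.
nt = values(kw)

println(typeof(kw))   # a Pairs type; the exact printout varies across Julia versions
println(keys(nt))     # the keyword names that were passed
println(nt.save_idxs) # the range passed as `save_idxs`
```

This is the object the reverse pass has to propagate a gradient through, which is why a rule for pairs is relevant here.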

DhairyaLGandhi commented 4 years ago

Zygote doesn't attempt to differentiate kwargs. Here is a quick-and-dirty adjoint that I threw together just to check that hitting that method was the cause. It seems to work, but we'll want a more general adjoint here.

julia> Zygote.@adjoint function Base.pairs(x::T) where T
         y = Base.pairs(x)
         back(dx::NamedTuple) = (dx.data,)
         function back(dx::Dict)
           T <: AbstractDict && return (dx,)
           z = zero(x)
           for (k,v) in dx
             z[k] = v
           end
           (z,)
         end
         y, back
       end

ChrisRackauckas commented 4 years ago

Seems to fix the issue for me.

DhairyaLGandhi commented 4 years ago

It might be better to restrict the type to just a NamedTuple, since that would get rid of the original error and we wouldn't have to deal with the wonky dict part.

@adjoint function Base.pairs(x::NamedTuple)
  Base.pairs(x), Δ -> (Δ.data,)
end
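
As a rough sketch of what that pullback does, the same logic can be exercised without Zygote. The name pairs_with_pullback is a hypothetical stand-in, not part of any library, and the cotangent is built by hand to mirror the shape in the first stack trace (a NamedTuple with data and itr fields, with nothing for non-differentiable entries).

```julia
# Hypothetical stand-in for the adjoint above: forward value plus a pullback.
pairs_with_pullback(x::NamedTuple) = (pairs(x), Δ -> (Δ.data,))

kwnt = (x = ones(2), save_idxs = 1:1)
y, back = pairs_with_pullback(kwnt)

# Hand-built cotangent, shaped like the argument to ∂(pairs) in the first trace.
Δ = (data = (x = [1.0, 0.0], save_idxs = nothing), itr = nothing)

# The pullback unwraps `data`, giving a gradient keyed like the input NamedTuple.
dkw, = back(Δ)
println(dkw.x)
```

The pullback only forwards the data field, which is the direct field access that the follow-up comments want reviewed.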
DhairyaLGandhi commented 4 years ago

I don't think we should add adjoints that access internal fields directly for such things.

ChrisRackauckas commented 4 years ago

The adjoint on pairs seems sufficient for my applications. Is there any chance this could get in a tagged release by the end of the day?

DhairyaLGandhi commented 4 years ago

Sorry, it was like 2am when I did this, and I missed this while checking messages today.

I'd like some review on this before merging, since I'm not sure it's a good idea to access member fields in an adjoint.

Side note: this might just be something better handled in the compiler.

DhairyaLGandhi commented 4 years ago

This should be fixed on master.