FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

`intersect(itr, itrs...)` does not work with Zygote #154

Open GiggleLiu opened 5 years ago

GiggleLiu commented 5 years ago

I get a `StackOverflowError` when calling a function that contains `intersect`:

julia> gradient(x->(intersect((1,2), (2,3)); sum(x)), a)
ERROR: StackOverflowError:
Stacktrace:
 [1] _forward(::Zygote.Context, ::typeof(collect), ::Base.Generator{Base.Iterators.Filter{getfield(Base, Symbol("##83#84")){typeof(in),typeof(pop!),Set{Int64}},Tuple{Int64,Int64}},getfield(Base, Symbol("##85#86"))}) at /home/leo/.julia/dev/Zygote/src/lib/array.jl:92
 [2] map at ./abstractarray.jl:2044 [inlined]
 [3] _forward(::Zygote.Context, ::typeof(map), ::getfield(Base, Symbol("##85#86")), ::Base.Iterators.Filter{getfield(Base, Symbol("##83#84")){typeof(in),typeof(pop!),Set{Int64}},Tuple{Int64,Int64}}) at /home/leo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 ... (the last 3 lines are repeated 32442 more times)
 [97330] _forward(::Zygote.Context, ::typeof(collect), ::Base.Generator{Base.Iterators.Filter{getfield(Base, Symbol("##83#84")){typeof(in),typeof(pop!),Set{Int64}},Tuple{Int64,Int64}},getfield(Base, Symbol("##85#86"))}) at /home/leo/.julia/dev/Zygote/src/lib/array.jl:92
 [97331] _forward(::Zygote.Context, ::typeof(Base.vectorfilter), ::getfield(Base, Symbol("##83#84")){typeof(in),typeof(pop!),Set{Int64}}, ::Tuple{Int64,Int64}) at ./array.jl:2393
 [97332] _shrink at ./array.jl:2397 [inlined]
 [97333] _forward(::Zygote.Context, ::typeof(Base._shrink), ::typeof(intersect!), ::Tuple{Int64,Int64}, ::Tuple{Tuple{Int64,Int64}}) at /home/leo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [97334] intersect at ./array.jl:2400 [inlined]
 [97335] _forward(::Zygote.Context, ::typeof(intersect), ::Tuple{Int64,Int64}, ::Tuple{Int64,Int64}) at /home/leo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [97336] #167 at ./none:1 [inlined]
 [97337] _forward(::Zygote.Context, ::getfield(Main, Symbol("##167#168")), ::Array{Float64,2}) at /home/leo/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [97338] _forward(::Function, ::Array{Float64,2}) at /home/leo/.julia/dev/Zygote/src/compiler/interface.jl:31
 [97339] forward(::Function, ::Array{Float64,2}) at /home/leo/.julia/dev/Zygote/src/compiler/interface.jl:37
 [97340] gradient(::Function, ::Array{Float64,2}) at /home/leo/.julia/dev/Zygote/src/compiler/interface.jl:46

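This fails because Zygote's forward rule for `collect` ends up in an infinite recursion:

# frame [97330]: the rule for collect (src/lib/array.jl:92) forwards to map
function _forward(cx::Context, ::typeof(collect), g::Base.Generator)
  y, back = _forward(cx, map, g.f, g.iter)
...

# which calls into frame [3], the generic generated _forward
@generated function _forward(ctx::Context, f, args...)
  T = Tuple{f,args...}

# which then differentiates frame [2], Base's fallback for map
map(f, A) = collect(Generator(f,A))   # collect again, so the cycle repeats until the stack overflows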

This issue should be taken seriously since any program contain intersect and setdiff would fail due to this bug. @MikeInnes Do you have any idea to fix it?
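The cycle described above can be reproduced in miniature without Zygote. The Julia sketch below is a toy model under stated assumptions: `fwd_collect`, `fwd_map` and `fwd_collect_general` are hypothetical stand-ins for the rules in the trace, not Zygote internals, and an explicit depth limit replaces the real `StackOverflowError`. It also shows one possible way to make the `collect` rule more general (materialize the iterator before mapping, so the array-only path always terminates); this is an illustration, not necessarily the change Zygote adopted.

```julia
# Toy model of the collect -> map -> collect cycle in the stack trace.
# fwd_collect, fwd_map and fwd_collect_general are hypothetical stand-ins,
# not Zygote internals.

# Like the old rule: rewrite collect(::Generator) as map(g.f, g.iter).
fwd_collect(g::Base.Generator, depth::Int) = fwd_map(g.f, g.iter, depth + 1)

# Like dispatch on map: only AbstractArray inputs reach a terminating method;
# anything else (e.g. an Iterators.Filter) falls back to
# map(f, A) = collect(Generator(f, A)), which re-enters fwd_collect.
function fwd_map(f, A, depth::Int)
    A isa AbstractArray && return map(f, A)
    # An explicit limit stands in for the real StackOverflowError.
    depth > 10 && error("cycle detected: collect -> map -> collect -> ...")
    return fwd_collect(Base.Generator(f, A), depth + 1)
end

# One possible generalization (an assumption, not necessarily Zygote's fix):
# materialize the iterator first, so map always receives an array.
fwd_collect_general(g::Base.Generator) = map(g.f, collect(g.iter))

# A non-array iterator, like the Filter in the trace.
it = Iterators.filter(isodd, (1, 2, 3))
fwd_collect_general(Base.Generator(x -> 2x, it))  # returns [2, 6]
```

Calling `fwd_collect(Base.Generator(x -> 2x, it), 0)` instead hits the depth limit, because the `Filter` never becomes an `AbstractArray` on its way around the loop.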

MikeInnes commented 5 years ago

This definition just needs to be more general. We could probably just remove the `AbstractArray` restriction, but we need to be a bit careful that `map` over tuples is still fast.

GiggleLiu commented 5 years ago

Thanks. With this definition removed there is no infinite loop any more, but `intersect` and `setdiff` still don't work, since they rely on the #mutable branch.

but we need to be a bit careful that `map` over tuples is still fast.

I didn't quite follow the performance concern you mentioned, so I submitted a WIP PR, #156, that is open to more performance tests.

MikeInnes commented 5 years ago

Generally, things like `map(identity, (1, "foo"))` should infer, and so should their adjoints. As it currently stands, the PR breaks that, though it's fine if we fix it later.
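The forward half of this requirement can be checked with nothing but Base and the `Test` standard library: `@inferred` throws if the inferred return type differs from the actual one. This is a minimal sketch; checking that the adjoints also infer would need Zygote itself and is not shown. The second half illustrates, as an assumption about the concern, why sending tuples through the generic `collect(Generator(f, A))` path would be a regression: it produces a `Vector`, not a `Tuple`.

```julia
using Test  # provides @inferred

# map over a small heterogeneous tuple has a dedicated method in Base, so the
# result type Tuple{Int, String} is inferred exactly; @inferred errors otherwise.
t = @inferred map(identity, (1, "foo"))

# The generic fallback path builds an array instead, losing the Tuple type.
v = collect(Base.Generator(identity, (1, "foo")))
```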

ToucheSir commented 2 years ago

The internals referred to here no longer exist, but `intersect(::Tuple, ::Tuple)` still doesn't work. Here's the current state of things:

julia> gradient(x->(intersect((1,2), (2,3)); sum(x)), a)
ERROR: MethodError: no method matching getindex(::Dict{Any, Any})
Closest candidates are:
  getindex(::Dict{K, V}, ::Any) where {K, V} at ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/base/dict.jl:479
  getindex(::AbstractDict, ::Any) at ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/base/abstractdict.jl:496
  getindex(::AbstractDict, ::Any, ::Any, ::Any...) at ~/.julia/juliaup/julia-1.7.1+0~x64/share/julia/base/abstractdict.jl:506
Stacktrace:
  [1] (::Zygote.var"#225#226"{Symbol, Dict{Any, Any}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/lib/lib.jl:277
  [2] (::Zygote.var"#1775#back#227"{Zygote.var"#225#226"{Symbol, Dict{Any, Any}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [3] Pullback
    @ ./Base.jl:43 [inlined]
  [4] (::typeof(∂(setproperty!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [5] Pullback
    @ ./dict.jl:634 [inlined]
  [6] (::typeof(∂(_delete!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [7] Pullback
    @ ./dict.jl:664 [inlined]
  [8] (::typeof(∂(delete!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
  [9] Pullback
    @ ./set.jl:68 [inlined]
 [10] (::typeof(∂(delete!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [11] Pullback
    @ ./abstractset.jl:444 [inlined]
 [12] (::typeof(∂(mapfilter)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./abstractset.jl:439 [inlined]
 [14] (::typeof(∂(unsafe_filter!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [15] Pullback
    @ ./set.jl:422 [inlined]
 [16] Pullback
    @ ./abstractset.jl:146 [inlined]
 [17] Pullback
    @ ./abstractset.jl:147 [inlined]
 [18] (::typeof(∂(intersect!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [19] #213
    @ ~/.julia/packages/Zygote/umM0L/src/lib/lib.jl:203 [inlined]
 [20] #1752#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [21] Pullback
    @ ./array.jl:2632 [inlined]
 [22] (::typeof(∂(_shrink)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [23] Pullback
    @ ./array.jl:2636 [inlined]
 [24] Pullback
    @ ./REPL[4]:1 [inlined]
 [25] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#57#58"{typeof(∂(#3))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:41
 [27] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/umM0L/src/compiler/interface.jl:76
 [28] top-level scope
    @ REPL[4]:1

I've repurposed this issue as a tracker for intersect support.