FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

missing rules for `repeat` #906

Open CarloLucibello opened 3 years ago

CarloLucibello commented 3 years ago

Zygote is missing some `repeat` adjoints:

# this is OK
julia> gradient(x -> sum(repeat(x, outer=(2,2,2))), reshape(1:8, 2,2,2))
([8 8; 8 8]

[8 8; 8 8],)

# missing rule
julia> gradient(x -> sum(repeat(x, 2, 2, 2)), reshape(1:8, 2,2,2))
ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#372#373")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/lib/array.jl:58
  [3] (::Zygote.var"#2249#back#374"{Zygote.var"#372#373"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./abstractarraymath.jl:365 [inlined]
  [5] (::typeof(∂(repeat_outer)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./abstractarraymath.jl:327 [inlined]
  [7] Pullback
    @ ./abstractarraymath.jl:269 [inlined]
  [8] (::typeof(∂(#repeat#1)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
  [9] Pullback
    @ ./abstractarraymath.jl:267 [inlined]
 [10] (::typeof(∂(repeat##kw)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [11] Pullback
    @ ./abstractarraymath.jl:224 [inlined]
 [12] (::typeof(∂(repeat)))(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [13] (::Zygote.var"#151#152"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing, Nothing}}, typeof(∂(repeat))})(Δ::FillArrays.Fill{Int64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/lib/lib.jl:191
 [14] #1682#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [15] Pullback
    @ ./REPL[13]:1 [inlined]
 [16] (::Zygote.var"#41#42"{typeof(∂(#27))})(Δ::Int64)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:40
 [17] gradient(f::Function, args::Base.ReshapedArray{Int64, 3, UnitRange{Int64}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
 [18] top-level scope
    @ REPL[13]:1

willtebbutt commented 3 years ago

Good catch. @CarloLucibello would you mind moving this over to ChainRules?

CarloLucibello commented 3 years ago

The `repeat` adjoints are still in Zygote: https://github.com/FluxML/Zygote.jl/blob/956cbcf3c572c0eb09c146189bb38b1b434634ff/src/lib/array.jl#L130. Not sure why they weren't ported to ChainRules; probably just lack of time.
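
For reference, here is a minimal sketch of what the pullback for the positional form `repeat(x, counts...)` has to compute. It assumes only Base Julia; `repeat_outer_pullback` is a hypothetical name for illustration, not an existing Zygote or ChainRules function, and this is not necessarily how either package implements its rule. The idea is that each element of `x` appears once per tile in the output, so its gradient is the sum of the incoming cotangent over all tiles.

```julia
# Hypothetical helper (not part of Zygote or ChainRules): pullback of
# y = repeat(x, counts...) given the cotangent Δ of y.
function repeat_outer_pullback(Δ::AbstractArray, xsize::Dims, counts::Dims)
    N = max(length(xsize), length(counts))
    # Pad the input size and the repeat counts with 1s to a common rank N.
    s = ntuple(d -> d <= length(xsize) ? xsize[d] : 1, N)
    c = ntuple(d -> d <= length(counts) ? counts[d] : 1, N)
    # Split every output dimension (length s[d] * c[d]) into (s[d], c[d]).
    split = ntuple(i -> isodd(i) ? s[(i + 1) ÷ 2] : c[i ÷ 2], 2N)
    Δr = reshape(collect(Δ), split)   # collect also handles lazy arrays such as Fill
    # Sum over the tile dimensions (the even ones), then restore x's shape.
    summed = sum(Δr; dims = ntuple(d -> 2d, N))
    return reshape(summed, xsize)
end

x = reshape(collect(1.0:8.0), 2, 2, 2)
y = repeat(x, 2, 2, 2)
# Cotangent of sum(y) is all ones; every x[i] occurs 2 * 2 * 2 = 8 times in y.
g = repeat_outer_pullback(ones(size(y)), size(x), (2, 2, 2))
```

Here `g` is a 2×2×2 array of 8s, matching the result of the `outer=(2,2,2)` call in the first comment. A function like this could serve as the body of a rule for the positional method, with non-differentiable tangents returned for the integer counts.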