JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core that allows you to define rules for your functions in your packages, without depending on any particular AD system.

Error in accumulate when I have one argument as a tuple #664

Open pevnak opened 6 months ago

pevnak commented 6 months ago

Hello,

For educational purposes, I have been implementing an RNN by hand, and I wanted to be fancy and use accumulate instead of recursion or a for loop. But I run into an error when one of the operands in accumulate is a tuple. I have carved out an MWE, which looks like this:

using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

function f(α, h, x)
    o = accumulate(x, init = h) do h, x
        α * h + x
    end
end

function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h,_),x
        (α * h + x, x)
    end
    first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]

While computing the gradient of f succeeds, computing the gradient of g crashes with:

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})

Closest candidates are:
  construct(::Type{T}, ::T) where T<:Tuple
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
  construct(::Type{T}, ::NamedTuple{L}) where {T, L}
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235

Stacktrace:
  [1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
  [2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
  [3] iterate(itr::Base.Iterators.Accumulate)
    @ Base.Iterators ./iterators.jl:589 [inlined]
  [4] collect_to!
    @ ./array.jl:892 [inlined]
  [5] collect_to_with_first!
    @ ./array.jl:870 [inlined]
  [6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:864 [inlined]
  [7] collect(itr::Base.Generator)
    @ Base ./array.jl:759 [inlined]
  [8] #accumulate#893
    @ ./accumulate.jl:281 [inlined]
  [9] accumulate
    @ ./accumulate.jl:278 [inlined]
 [10] (::ChainRules.var"#decumulate#1701"{…})(dy::Vector{…})
    @ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
 [11] ZBack
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{…}})(dy::Vector{Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
 [13] g
    @ ./REPL[43]:2 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] #53
    @ ./REPL[44]:1 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [18] gradient(f::Function, args::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [19] top-level scope
    @ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.

Julia and environment

julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
  Threads: 1 on 8 virtual cores

(tmp) pkg> st
Status `/private/tmp/Project.toml`
  [082447d4] ChainRules v1.63.0
  [d360d2e6] ChainRulesCore v1.21.1
  [26cc04aa] FiniteDifferences v0.12.31
  [587475ba] Flux v0.14.11
  [3bd65402] Optimisers v0.3.2
  [eeda0dda] SafeTensors v1.0.0
  [2913bbd2] StatsBase v0.34.2
  [e88e6eb3] Zygote v0.6.69

Thanks for the help.

nmheim commented 6 months ago

Zygote is constructing tangents that enter the decumulate pullback via wrap_chainrules_output. In this case it's hitting the method for Union{Tuple,NamedTuple}, which is interesting, because I think it should be using the method for Tuple.

I think this could be fixed by making sure wrap_chainrules_output returns a StructuralTangent... or at least, if I do the following in Zygote:

@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
  xp = map(wrap_chainrules_input, dxs)
  # This produces Tangent{Any} since it does not get to see the primal, `x`.
  # ChainRulesCore.Tangent{Any, typeof(xp)}(xp) -- comment this out and replace by line below
  ChainRulesCore.StructuralTangent{typeof(xp)}(xp)
end

things seem to work out.
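A rough sketch of the mismatch the stack trace points at, using only ChainRulesCore: frame [1] adds a tangent whose primal type is a concrete Tuple to one whose primal type is Any. The backing values and the primal type Tuple{Float32, Float32} below are made up for illustration and are not taken from the MWE, and the exact behaviour of the mixed addition may differ between ChainRulesCore versions, so it is wrapped in try/catch rather than asserted.

```julia
using ChainRulesCore

# Hypothetical backing data, similar in shape to the tuple in the error message.
xp = (1.0f0, NoTangent())

# A tangent built without seeing the primal, so its primal type is Any.
t_any = Tangent{Any, typeof(xp)}(xp)

# A tangent that records a concrete tuple primal type (assumed here).
t_typed = Tangent{Tuple{Float32, Float32}, typeof(xp)}(xp)

# Adding two tangents with the same primal type works elementwise.
t_sum = t_typed + t_typed

# Mixing primal types is the situation in frame [1] of the stack trace.
# Capture whatever happens so the script runs to completion either way.
mixed = try
    t_typed + t_any
catch err
    err
end
println(typeof(mixed))
```

If the printed type is a MethodError, this reproduces the failure in isolation, which would suggest the problem is the Tangent{Any} entering the pullback rather than anything specific to accumulate.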

mcabbott commented 6 months ago

Same error with https://github.com/JuliaDiff/ChainRules.jl/pull/569, FWIW.

Not certain this is relevant, but notice the similarity to this:

julia> accumulate(=>, (1,2,3))
(1, 1 => 2, (1 => 2) => 3)

julia> accumulate(=>, [1,2,3])
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Pair{Int64, Int64}

and that this gradient works with x::Tuple:

julia> gradient(α -> sum(sum(g(α, h, Tuple(x)))), 1f0)[1]
15.059713f0

julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]  # with x::Vector as above
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})