JuliaDiff / Diffractor.jl

Next-generation AD
MIT License

Failure with ComposedFunction `∘` #67

Open mcabbott opened 2 years ago

mcabbott commented 2 years ago

I'm pretty confident this worked in September, but I have no idea whether changes here or in ChainRules broke it. Edit: the change is https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495, discussed there.

julia> Diffractor.gradient(cbrt, 1.23)
(0.29036348772107673,)

julia> Diffractor.gradient(identity∘cbrt, 1.23)
ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}.
Stacktrace:
  [1] _backing_error(P::Type, G::Type, E::Type)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/qzYOG/src/tangent_types/tangent.jl:62
  [2] ChainRulesCore.Tangent{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}}(backing::NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/qzYOG/src/tangent_types/tangent.jl:33
  [3] (::Diffractor.var"#162#164"{Symbol, DataType})(Δ::ChainRulesCore.ZeroTangent)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:308
  [4] (::Diffractor.EvenOddOdd{1, 1, Diffractor.var"#162#164"{Symbol, DataType}, Diffractor.var"#163#165"{Symbol}})(Δ::ChainRulesCore.ZeroTangent)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:288
  [5] ∂⃖¹₁merge
    @ ./none:1
  [6] ∂⃖¹₁
    @ ./none:1
  [7] (::Diffractor.ApplyOdd{1, 1})(Δ::Float64)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:371
  [8] ∂⃖¹₁ComposedFunction
    @ ./none:1
  [9] (::Diffractor.∇{ComposedFunction{typeof(identity), typeof(cbrt)}})(args::Float64)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/interface.jl:122
 [10] Diffractor.∇(::Function, ::Float64)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/interface.jl:128
 [11] top-level scope
    @ REPL[3]:1
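
[Editor's note] A hedged sketch of where the failing primal type comes from, not taken from the issue itself: calling a function through `ComposedFunction`'s machinery materializes an empty keyword-argument container of type `Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}`, the primal named in the `ArgumentError` above. ChainRulesCore's check then rejects a `NamedTuple`-backed `Tangent` for that dict-like primal:

```julia
# Empty keyword arguments are represented as a Base.Pairs wrapping an
# empty NamedTuple -- the primal type appearing in the error message.
kw = Base.pairs(NamedTuple())
@show typeof(kw)   # Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}

# With ChainRulesCore loaded, building a NamedTuple-backed Tangent for this
# primal (as the frame at generated.jl:308 does) should trip the same check:
#   using ChainRulesCore
#   Tangent{typeof(kw)}(data = ZeroTangent())   # throws the ArgumentError above
```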
mcabbott commented 2 years ago

BTW, if you disable that check by defining ChainRulesCore._backing_error(P, G, E) = nothing, then it fails further along by producing both a Tuple and a Tangent for the same gradient. This case was not addressed in https://github.com/JuliaDiff/Diffractor.jl/pull/88:

Stacktrace:
 [1] accum(a::Tangent{Tuple{typeof(identity), typeof(cbrt)}, Tuple{NoTangent, NoTangent}}, b::Tuple{NoTangent, ZeroTangent})
   @ Diffractor ~/.julia/dev/Diffractor/src/runtime.jl:4
 [2] (::Tuple{Diffractor.EvenOddOdd{1, 1, Diffractor.∂⃖getfield{2, 1}, Diffractor.var"#160#161"{Int64}}, ChainRules.var"#tail_pullback#1599"{Tuple{typeof(identity), typeof(cbrt)}}, Core.OpaqueClosure{Tuple{Any}, Tuple{ZeroTangent, Tuple{NoTangent}, Tuple{Any}, Tangent{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, NamedTuple{(:data,), Tuple{ZeroTangent}}}}}, ChainRules.var"#identity_pullback#1221", Nothing})(Δ::Any)
   @ Diffractor ./none:1
 [3] (::Tuple{Core.OpaqueClosure{Tuple{Any}, Tuple{ZeroTangent, Union{NoTangent, Tangent{ComposedFunction{typeof(identity), typeof(cbrt)}}}}}, Core.OpaqueClosure{Tuple{Any}, Union{}}, Nothing})(Δ::Any)
   @ Diffractor ./none:1
 [4] (::Diffractor.ApplyOdd{1, 1})(Δ::Float64)
   @ Diffractor ~/.julia/dev/Diffractor/src/stage1/generated.jl:373
 [5] (::Tuple{Core.OpaqueClosure{Tuple{Any}, Tuple{ZeroTangent}}, Core.OpaqueClosure{Tuple{Any}, Tuple{ZeroTangent, Any}}, Diffractor.EvenOddOdd{1, 1, Diffractor.tuple_back{2}, Diffractor.var"#176#177"}, Diffractor.ApplyOdd{1, 1}, Nothing})(Δ::Any)
   @ Diffractor ./none:1
 [6] (::Diffractor.∇{ComposedFunction{typeof(identity), typeof(cbrt)}})(args::Float64)
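
[Editor's note] A hedged sketch of the mismatched pair that reaches `accum` in frame [1] above, reconstructed from the printed signature rather than from Diffractor internals: one pullback path yields a structural `Tangent` for the `(identity, cbrt)` tuple while another yields a plain `Tuple`, and `accum` has no method reconciling the two representations of the same gradient:

```julia
using ChainRulesCore  # assumed available

# The structural tangent one path produces for the (identity, cbrt) tuple...
a = Tangent{Tuple{typeof(identity), typeof(cbrt)}}(NoTangent(), NoTangent())

# ...and the plain Tuple the other path produces for the same gradient.
b = (NoTangent(), ZeroTangent())

# Diffractor.accum(a, b) is the call that fails at runtime.jl:4, since no
# accum method covers a Tangent/Tuple mix for one and the same gradient.
```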