JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

`frule` for `sum` doesn't work for `Generator` #768

Open MasonProtter opened 10 months ago

MasonProtter commented 10 months ago

I think this is a ChainRules.jl problem, not a Diffractor.jl problem, but I'm not 100% sure. Here's an MWE:

julia> using Diffractor: DiffractorForwardBackend; using AbstractDifferentiation: derivative, jacobian

julia> derivative(DiffractorForwardBackend(), x -> sum(x+i for i ∈ 1:100), 1.0)
ERROR: MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; dims::Colon)

Closest candidates are:
  mapfoldl(::Any, ::Any, ::Any; init) got unsupported keyword argument "dims"
   @ Base reduce.jl:175
  mapfoldl(::F, ::R, ::StaticArraysCore.StaticArray; init) where {F, R} got unsupported keyword argument "dims"
   @ StaticArrays ~/.julia/packages/StaticArrays/eGKzB/src/mapreduce.jl:255

Stacktrace:
  [1] kwerr(::@NamedTuple{dims::Colon}, ::Function, ::Function, ::Function, ::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}})
    @ Base ./error.jl:165
  [2] mapreduce(f::Function, op::Function, itr::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
    @ Base ./reduce.jl:307
  [3] sum(f::Function, a::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
    @ Base ./reduce.jl:535
  [4] sum(a::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
    @ Base ./reduce.jl:564
  [5] frule(::Tuple{…}, ::typeof(sum), x::Base.Generator{…}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/mapreduce.jl:10
  [6] frule(::Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.Tangent{…}}, ::typeof(sum), x::Base.Generator{UnitRange{…}, var"#7#9"{…}})
    @ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/mapreduce.jl:9
  [7] (::Diffractor.∂☆internal{1})(::ZeroBundle{1, typeof(sum)}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:110
  [8] (::Diffractor.∂☆{1})(::ZeroBundle{1, typeof(sum)}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:151
  [9] #6
    @ ./REPL[5]:1 [inlined]
 [10] (::Diffractor.∂☆recurse{1})(::ZeroBundle{1, var"#6#8"}, ::Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/recurse_fwd.jl:0
 [11] (::Diffractor.∂☆internal{1})(::ZeroBundle{1, var"#6#8"}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:112
 [12] (::Diffractor.∂☆{1})(::ZeroBundle{1, var"#6#8"}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:151
 [13] (::Diffractor.var"#pushforward#369"{var"#6#8", Tuple{Float64}})(vs::Tuple{Float64})
    @ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/AbstractDifferentiation.jl:17
 [14] jacobian(b::DiffractorForwardBackend, f::Function, args::Float64)
    @ Diffractor ~/.julia/packages/AbstractDifferentiation/1Cavg/src/AbstractDifferentiation.jl:537
 [15] derivative(ab::DiffractorForwardBackend, f::Function, xs::Float64)
    @ AbstractDifferentiation ~/.julia/packages/AbstractDifferentiation/1Cavg/src/AbstractDifferentiation.jl:34
 [16] top-level scope
    @ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.

This appears to be caused by the frule definition at https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L9, which I believe may be a junk method. Not only does it assume you can pass dims=: (which, as the error above shows, a Generator does not support), it also assumes that sum(ẋ; dims=dims) makes sense. That appears to be incorrect: ẋ is going to be a Tangent, so this sum isn't actually summing the contents of the Tangent but rather operating on the outer object (I could be misunderstanding, though).
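The dims=: half of this can be checked without any AD package. A minimal sketch in plain Julia, assuming a Julia version whose Base reductions behave as in the stack trace above: for a non-array iterable, sum forwards its keyword arguments down to mapfoldl, which only accepts init, so the plain call works while the call with dims=: (the form the frule uses) throws a MethodError.

```julia
# A generator like the one in the MWE, with x fixed at 1.0.
x = 1.0
gen = (x + i for i in 1:100)

# Without `dims`, Base folds over the generator element by element.
@show sum(gen)  # → 5150.0

# With `dims=:`, the keyword is forwarded to `mapfoldl`, which has no
# `dims` keyword, so Base raises a MethodError (as in the trace above).
err = try
    sum(gen; dims=:)
    nothing
catch e
    e
end
@show err isa MethodError
```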

I tried adding methods like

import ChainRulesCore

function ChainRulesCore.frule((_, ẋ), ::typeof(sum), x::Base.Generator)
   return sum(x), sum(ẋ)
end

but that gave incorrect answers, I think because of the sum(ẋ).

Instead, I found that if I just did Base.delete_method on the frule(::Tuple{…}, ::typeof(sum), x::Base.Generator{…}; dims::Function) method, I got the right results.
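For reference, the right result here is 100.0: each of the 100 terms x + i contributes 1 to the derivative. The sketch below uses a hypothetical hand-rolled Dual type (for illustration only, not Diffractor's or ChainRules' machinery) to show that forward mode can get this by differentiating straight through Base's generic sum over the generator, with no sum-specific rule. This is presumably what deleting the method lets Diffractor fall back to.

```julia
# Hypothetical minimal dual number: `val` is the primal, `der` the tangent.
struct Dual
    val::Float64
    der::Float64
end

# Tangent propagation for the only operation the MWE needs: addition.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)  # constants carry zero tangent
Base.:+(a::Real, b::Dual) = b + a

# Same function as the MWE. No rule for `sum` is involved: Base's generic
# reduction simply calls `+` on the Dual values produced by the generator.
f(x) = sum(x + i for i in 1:100)

d = f(Dual(1.0, 1.0))  # seed the input tangent with 1
@show d.val d.der      # primal 5150.0, derivative 100.0
```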

oxinabox commented 10 months ago

I think yes, this frule (https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L9) should just be deleted, or at least restricted to Array.

Would you like to make the PR?
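The "restricted to Array" option amounts to narrowing the dispatch signature so that non-array iterables such as Generator no longer match the rule and fall back to "no rule, recurse instead". A minimal sketch of that dispatch pattern in plain Julia, where sum_frule and its nothing-returning fallback are hypothetical stand-ins for ChainRulesCore.frule and its default (this is not the actual ChainRules code), and where the tangent ẋ is assumed to be array-like whenever x is an array:

```julia
# Hypothetical fallback: `nothing` means "no rule here, let the AD recurse".
sum_frule(tangents, op, x) = nothing

# Hypothetical array-only rule. `sum` is linear, so the output tangent is the
# sum of the input tangents, and `dims` is meaningful for arrays.
function sum_frule((_, ẋ), ::typeof(sum), x::AbstractArray; dims=:)
    return sum(x; dims=dims), sum(ẋ; dims=dims)
end

# Arrays still hit the rule...
y, ẏ = sum_frule((nothing, [1.0, 1.0, 1.0]), sum, [1.0, 2.0, 3.0])
@show y ẏ

# ...while a Generator no longer matches it and gets the fallback instead.
@show sum_frule((nothing, nothing), sum, (i for i in 1:3)) === nothing
```

The sketch restricts to AbstractArray rather than Array; either choice keeps Generator out of the rule, and which one the package would want is a judgment call.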