I think this is a ChainRules.jl problem, not a Diffractor.jl problem, but I'm not 100% sure. Here's an MWE:
```julia
julia> using Diffractor: DiffractorForwardBackend; using AbstractDifferentiation: derivative, jacobian
julia> derivative(DiffractorForwardBackend(), x -> sum(x+i for i ∈ 1:100), 1.0)
ERROR: MethodError: no method matching mapfoldl(::typeof(identity), ::typeof(Base.add_sum), ::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; dims::Colon)
Closest candidates are:
mapfoldl(::Any, ::Any, ::Any; init) got unsupported keyword argument "dims"
@ Base reduce.jl:175
mapfoldl(::F, ::R, ::StaticArraysCore.StaticArray; init) where {F, R} got unsupported keyword argument "dims"
@ StaticArrays ~/.julia/packages/StaticArrays/eGKzB/src/mapreduce.jl:255
Stacktrace:
[1] kwerr(::@NamedTuple{dims::Colon}, ::Function, ::Function, ::Function, ::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}})
@ Base ./error.jl:165
[2] mapreduce(f::Function, op::Function, itr::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
@ Base ./reduce.jl:307
[3] sum(f::Function, a::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
@ Base ./reduce.jl:535
[4] sum(a::Base.Generator{UnitRange{Int64}, var"#7#9"{Float64}}; kw::@Kwargs{dims::Colon})
@ Base ./reduce.jl:564
[5] frule(::Tuple{…}, ::typeof(sum), x::Base.Generator{…}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/mapreduce.jl:10
[6] frule(::Tuple{ChainRulesCore.ZeroTangent, ChainRulesCore.Tangent{…}}, ::typeof(sum), x::Base.Generator{UnitRange{…}, var"#7#9"{…}})
@ ChainRules ~/.julia/packages/ChainRules/pEOSw/src/rulesets/Base/mapreduce.jl:9
[7] (::Diffractor.∂☆internal{1})(::ZeroBundle{1, typeof(sum)}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:110
[8] (::Diffractor.∂☆{1})(::ZeroBundle{1, typeof(sum)}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:151
[9] #6
@ ./REPL[5]:1 [inlined]
[10] (::Diffractor.∂☆recurse{1})(::ZeroBundle{1, var"#6#8"}, ::Diffractor.TangentBundle{1, Float64, Diffractor.TaylorTangent{Tuple{Float64}}})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/recurse_fwd.jl:0
[11] (::Diffractor.∂☆internal{1})(::ZeroBundle{1, var"#6#8"}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:112
[12] (::Diffractor.∂☆{1})(::ZeroBundle{1, var"#6#8"}, ::Vararg{Diffractor.AbstractTangentBundle{1}})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/stage1/forward.jl:151
[13] (::Diffractor.var"#pushforward#369"{var"#6#8", Tuple{Float64}})(vs::Tuple{Float64})
@ Diffractor ~/.julia/packages/Diffractor/QY5lC/src/AbstractDifferentiation.jl:17
[14] jacobian(b::DiffractorForwardBackend, f::Function, args::Float64)
@ Diffractor ~/.julia/packages/AbstractDifferentiation/1Cavg/src/AbstractDifferentiation.jl:537
[15] derivative(ab::DiffractorForwardBackend, f::Function, xs::Float64)
@ AbstractDifferentiation ~/.julia/packages/AbstractDifferentiation/1Cavg/src/AbstractDifferentiation.jl:34
[16] top-level scope
@ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
This appears to be caused by the `frule` definition here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L9. I believe it may be a junk method, for two reasons. First, it assumes you can pass `dims=:` to `sum`, which fails for a `Base.Generator` (that call ends up in `mapfoldl`, which has no `dims` keyword). Second, it assumes that `sum(ẋ; dims=dims)` makes sense, which appears to be incorrect: `ẋ` is going to be a `Tangent`, so this `sum` isn't actually summing the contents of the `Tangent` but rather giving back the outer object (I could be misunderstanding this, though).
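The first point can be checked in plain Base Julia, with neither ChainRules nor Diffractor loaded. The sketch below assumes a Julia version where `sum` on a `Generator` still forwards to `mapfoldl`, as in the stack trace above; it shows that `sum(gen; dims=:)` throws a `MethodError`, while `sum` without `dims`, or `sum` on a collected array, works:

```julia
# Same generator as in the MWE, evaluated at x = 1.0.
gen = (1.0 + i for i in 1:100)

# Call sum with dims=: on a Generator, as the frule does, and capture the error.
dims_error = try
    sum(gen; dims=:)
    nothing
catch err
    err
end

println(typeof(dims_error))        # the error type raised by the dims=: call
println(sum(gen))                  # no dims keyword: works on the Generator
println(sum(collect(gen); dims=:)) # arrays do accept dims=:
```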
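I tried adding methods like

```julia
using ChainRulesCore
using Base: Generator

function ChainRulesCore.frule((_, ẋ), ::typeof(sum), x::Generator;)
    return sum(x), sum(ẋ)  # sum(ẋ) is the suspected source of the wrong answers
end
```

but that gave incorrect answers, I think because of the `sum(ẋ)`.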
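For context on why `sum(ẋ)` is suspect: a `Base.Generator` is a struct with two fields, `f` and `iter`, and when `x` is an argument of the enclosing function (as in the MWE's lambda), the generator's closure stores it as a captured field. The sketch below uses a hypothetical helper `make_gen` to show this layout. Assuming the `Tangent` passed for the generator mirrors this struct (which is how ChainRulesCore represents structural tangents in general), the perturbation of `x` would sit nested inside the `f` field rather than being one entry per element, which would explain why summing `ẋ` doesn't produce the sum of per-element tangents.

```julia
# Hypothetical helper: x is a function argument, so the generator's closure captures it.
make_gen(x) = (x + i for i in 1:100)

g = make_gen(1.0)

println(fieldnames(typeof(g))) # the Generator's fields: the mapped function and the iterator
println(g.f.x)                 # the captured x lives inside the closure g.f
println(g.iter)                # the range being mapped over
```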
Instead, I found that if I just deleted the `frule(::Tuple{…}, ::typeof(sum), x::Base.Generator{…}; dims::Function)` method with `Base.delete_method`, I got the right results.
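For reference, the right result here is 100: each of the 100 terms `x + i` has derivative 1 with respect to `x`. The sketch below checks this with a minimal, hypothetical `Dual` type (not part of Diffractor or ChainRules) that pushes the tangent through each element of the generator and lets `sum` add up the per-element tangents, which is the value a correct pushforward for `sum` over a generator should reproduce:

```julia
# Minimal forward-mode dual number (hypothetical, for illustration only).
struct Dual <: Number
    val::Float64  # primal value
    der::Float64  # tangent (derivative with respect to x)
end

# Tangents add under +; adding a constant leaves the tangent unchanged.
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:+(a::Real, b::Dual) = b + a

# Same function as in the MWE.
f(x) = sum(x + i for i in 1:100)

# Seed x = 1.0 with tangent 1.0; sum adds the 100 per-element tangents.
d = f(Dual(1.0, 1.0))
println(d.val)  # primal value of f(1.0)
println(d.der)  # derivative of f at 1.0
```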