mcabbott opened this issue 1 year ago
Calling `frule_via_ad(cfg, (NoTangent(), one(x)), f, x)` to work out the derivative works for numbers, but not in general. So this path:

https://github.com/JuliaDiff/ChainRules.jl/blob/5855c10bdbe691fc07822752f5b5865b9cea44d3/src/rulesets/Base/broadcast.jl#L104-L110

fails if broadcasting over an array of `Ref`, or almost any `struct`:
```julia
julia> using ChainRules, ChainRulesTestUtils

julia> CFG = ChainRulesTestUtils.TestConfig();

julia> ChainRules.split_bc_forwards(CFG, only, [Ref(1.0), Ref(2.0)])
ERROR: MethodError: no method matching one(::Base.RefValue{Float64})
Stacktrace:
  [1] (::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)})(a::Base.RefValue{Float64})
    @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:109
...
 [10] StructArray
    @ ~/.julia/packages/StructArrays/dNQpc/src/structarray.jl:254 [inlined]
 [11] unzip_broadcast(f::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)}, args::Vector{Base.RefValue{Float64}})
    @ ChainRules ~/.julia/dev/ChainRules/src/unzipped.jl:40
 [12] split_bc_inner
    @ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:108 [inlined]
 [13] split_bc_forwards(cfg::ChainRulesTestUtils.TestConfig, f::typeof(only), arg::Vector{Base.RefValue{Float64}})

julia> struct A594 x::Float64 end;

julia> ChainRules.split_bc_forwards(CFG, x -> x.x, A594.(1:3))
ERROR: MethodError: no method matching one(::A594)
```
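To make the failing pattern concrete, here is a minimal sketch with no ChainRules dependency. `toy_frule` and `derivative_via_seed` are hypothetical names, not ChainRules API: `toy_frule` is a hand-written forward rule for `sin` that follows the frule convention of taking a tuple of tangents (function tangent first) and returning `(output, output_tangent)`. Seeding the input tangent with `one(x)` makes the output tangent equal to the ordinary derivative, which is why the trick works for numbers. For `Ref(0.5)` the seed itself cannot be constructed, so the error occurs before any rule runs.

```julia
# Hypothetical stand-in for frule_via_ad: a forward rule for sin only.
# dargs is a tuple of tangents (tangent of the function first, then of x);
# it returns (primal output, output tangent).
toy_frule(dargs, ::typeof(sin), x) = (sin(x), cos(x) * dargs[2])

# Seed the input tangent with one(x), so the output tangent is f'(x).
# The function's own tangent is just `nothing` here (NoTangent() in ChainRules).
derivative_via_seed(f, x) = last(toy_frule((nothing, one(x)), f, x))

x = 0.5
d = derivative_via_seed(sin, x)   # equals cos(0.5)

# For a non-Number there is no `one` method, so building the seed fails
# with the same MethodError as in the issue.
err = try
    derivative_via_seed(sin, Ref(0.5))
    nothing
catch e
    e
end
```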
I don't see a requirement to define `one` (or perhaps `oneunit`) here https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html#Operations-on-a-tangent-type, but perhaps there ought to be such a thing?
I started to work on code that can define a basis seed here: https://github.com/JuliaComputing/Humpty.jl/blob/main/src/basis.jl. Maybe it is something we should think about.
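One possible reading of the "basis seed" idea, sketched here under stated assumptions rather than taken from the Humpty.jl code linked above: instead of a single `one(x)`, build one seed per field, each a `NamedTuple` in which that field is `one(...)` and every other field is `zero(...)`. `basis_seeds` and `Point` are hypothetical names, and the sketch assumes every field holds a `Number` (nested structs and arrays are not handled). Each seed has the shape of the backing of a ChainRulesCore `Tangent{T}`, so it could presumably be wrapped as `Tangent{T}(; seed...)` before being passed as an input tangent.

```julia
# Hypothetical helper, not part of ChainRules or ChainRulesCore.
# Assumes every field of `p` is a Number, so one/zero are defined for it.
function basis_seeds(p::T) where {T}
    names = fieldnames(T)
    # One seed per field: `target` gets one(...), all other fields get zero(...).
    map(names) do target
        vals = map(names) do n
            v = getfield(p, n)
            n === target ? one(v) : zero(v)
        end
        NamedTuple{names}(vals)
    end
end

struct A594
    x::Float64
end

struct Point
    x::Float64
    y::Float64
end

seeds = basis_seeds(Point(3.0, 4.0))
```

For `Point(3.0, 4.0)` this returns `((x = 1.0, y = 0.0), (x = 0.0, y = 1.0))`, and `Ref(2.0)` (a `Base.RefValue` with the single field `x`) gives `((x = 1.0,),)`. The trade-off is that a struct with `n` fields needs `n` forward passes, one per seed, rather than the single pass that `one(x)` allows for a scalar.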