JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

Cannot generate `frule` seed via `one(x)` #618

Open mcabbott opened 1 year ago

mcabbott commented 1 year ago

Calling `frule_via_ad(cfg, (NoTangent(), one(x)), f, x)` to work out the derivative works for numbers, but not in general. So this path:

https://github.com/JuliaDiff/ChainRules.jl/blob/5855c10bdbe691fc07822752f5b5865b9cea44d3/src/rulesets/Base/broadcast.jl#L104-L110

fails when broadcasting over an array of `Ref`s, or over an array of almost any struct:

```julia
julia> using ChainRules, ChainRulesTestUtils

julia> CFG = ChainRulesTestUtils.TestConfig();

julia> ChainRules.split_bc_forwards(CFG, only, [Ref(1.0), Ref(2.0)])
ERROR: MethodError: no method matching one(::Base.RefValue{Float64})
Stacktrace:
  [1] (::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)})(a::Base.RefValue{Float64})
    @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:109
...
 [10] StructArray
    @ ~/.julia/packages/StructArrays/dNQpc/src/structarray.jl:254 [inlined]
 [11] unzip_broadcast(f::ChainRules.var"#1732#1734"{typeof(frule_via_ad), ChainRulesTestUtils.TestConfig, typeof(only)}, args::Vector{Base.RefValue{Float64}})
    @ ChainRules ~/.julia/dev/ChainRules/src/unzipped.jl:40
 [12] split_bc_inner
    @ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:108 [inlined]
 [13] split_bc_forwards(cfg::ChainRulesTestUtils.TestConfig, f::typeof(only), arg::Vector{Base.RefValue{Float64}})

julia> struct A594 x::Float64 end;

julia> ChainRules.split_bc_forwards(CFG, x -> x.x, A594.(1:3))
ERROR: MethodError: no method matching one(::A594)
```

I don't see any requirement to define `one` (or perhaps `oneunit`) among the operations on a tangent type listed here: https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html#Operations-on-a-tangent-type. But perhaps there ought to be such a thing?
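To make the seeding idea concrete without depending on any AD package, here is a minimal sketch. `toy_frule` is a hypothetical hand-written forward rule (it does not follow ChainRulesCore's `frule` signature): it takes an input tangent and returns the primal output together with the output tangent. For a scalar input, seeding with `one(x)` makes the output tangent equal to the derivative. For a `Ref`, Base defines no `one` method, which is the same `MethodError` seen in the traceback above.

```julia
# Hypothetical hand-written forward rule (NOT the ChainRulesCore `frule` API):
# given the input tangent `xdot`, return (primal output, output tangent).
toy_frule(xdot, ::typeof(sin), x::Real) = (sin(x), cos(x) * xdot)

x = 0.5
# Seeding with one(x) turns the pushforward into the ordinary derivative.
y, dy = toy_frule(one(x), sin, x)
println("sin(x) = ", y, ", d/dx sin(x) = ", dy)

# For a Ref there is no `one`, so the same seeding strategy cannot even start.
seed_ok = try
    one(Ref(1.0))
    true
catch err
    err isa MethodError || rethrow()
    false
end
println("one(Ref(1.0)) defined? ", seed_ok)
```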

oxinabox commented 1 year ago

I started working on code that can define a basis seed, here: https://github.com/JuliaComputing/Humpty.jl/blob/main/src/basis.jl. Maybe this is something we should think about.
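As an illustration of what a "basis seed" could mean, here is a minimal sketch, assuming the input is either a real number or a flat struct whose fields are all `Real`. `basis_seeds` is a hypothetical helper; it is not the code in the linked Humpty.jl file, nor a ChainRulesCore function. A number gets the single seed `one(x)`, while a struct gets one seed per field, with that field set to one and the others set to zero. The seeds are plain `NamedTuple`s, which roughly mirror the backing of ChainRulesCore's structural `Tangent`; a real implementation would also need to handle nested structs, arrays, and non-differentiable fields.

```julia
# Hypothetical helper (not part of ChainRulesCore or Humpty.jl): return one seed
# per scalar "direction" of x, i.e. a basis for its tangent space.
basis_seeds(x::Real) = [one(x)]   # a number has a single direction, seeded by one(x)

function basis_seeds(x::T) where {T}
    names = fieldnames(T)
    isempty(names) && throw(ArgumentError("no fields to seed for $T"))
    vals = map(n -> getfield(x, n), names)
    all(v -> v isa Real, vals) ||
        throw(ArgumentError("this sketch only handles structs whose fields are all Real"))
    n = length(names)
    # Seed i: field i set to one, every other field set to zero.
    return [NamedTuple{names}(ntuple(j -> j == i ? one(vals[j]) : zero(vals[j]), n))
            for i in 1:n]
end

struct A594
    x::Float64
end

struct Point2
    a::Float64
    b::Float64
end

println(basis_seeds(2.0))
println(basis_seeds(Ref(1.0)))
println(basis_seeds(A594(3.0)))
println(basis_seeds(Point2(1.0, 2.0)))
```

A forward-mode caller could then push each seed through the rule in turn, collecting one directional derivative per field instead of relying on a single `one(x)`.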