JuliaDiff / AbstractDifferentiation.jl

An abstract interface for automatic differentiation.
https://juliadiff.org/AbstractDifferentiation.jl/
MIT License

Handling of thunks and tangents #132

Open gdalle opened 7 months ago

gdalle commented 7 months ago

When the VJP is not an abstract array (for example a NoTangent or a Thunk), second_derivative and hessian fail with a MethodError:

julia> import AbstractDifferentiation as AD

julia> import Zygote

julia> ad_backend = AD.ReverseRuleConfigBackend(Zygote.ZygoteRuleConfig())
AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}(Zygote.ZygoteRuleConfig{Zygote.Context{false}}(Zygote.Context{false}(nothing)))

julia> AD.second_derivative(ad_backend, identity, 1)
ERROR: MethodError: no method matching length(::ChainRulesCore.NoTangent)

julia> AD.hessian(ad_backend, sum, [1.0])
ERROR: MethodError: no method matching size(::ChainRulesCore.Thunk{ChainRulesCore.var"#48#49"{ChainRulesCore.Thunk{ChainRulesCore.var"#48#49"{…}}}})

oxinabox commented 7 months ago

The correct thing to do with thunks is to unthunk them before using them. The correct thing to do with NoTangent is generally to handle it specifically or, failing that, to pull information from the primal. For NoTangent it is more often the former; for ZeroTangent, more often the latter (see how Diffractor turns ZeroTangent into zero_tangent).
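
A minimal sketch of that advice, assuming only ChainRulesCore is loaded. The helper name `materialize` is hypothetical and is not part of AbstractDifferentiation.jl or ChainRulesCore. It unthunks lazy values first, and for both NoTangent and ZeroTangent it takes the fallback route of building a zero shaped like the primal. Real code may prefer to handle NoTangent specifically instead.

```julia
using ChainRulesCore

# Hypothetical helper (an assumption for illustration, not a library API):
# turn a raw cotangent `dx` into a plain value shaped like the primal `x`.

# Plain numbers and arrays are passed through unchanged.
materialize(dx, x) = dx

# Thunks are lazy: force them with `unthunk`, then process the result again.
materialize(dx::ChainRulesCore.AbstractThunk, x) = materialize(unthunk(dx), x)

# NoTangent and ZeroTangent carry no size information, so fall back to the
# primal to build a structural zero.
materialize(::ChainRulesCore.AbstractZero, x) = zero(x)

x = [1.0, 2.0]

materialize(@thunk(2 .* x), x)    # [2.0, 4.0]
materialize(NoTangent(), x)       # [0.0, 0.0]
materialize(ZeroTangent(), 1.0)   # 0.0
```

With cotangents normalized this way, calls such as `length` or `size` no longer hit a NoTangent or a Thunk, which is what triggers the two MethodErrors in the report above.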