Closed agerlach closed 1 year ago
Tracker.jl breakage is unrelated.
I believe this is correct. (But of course I do; I am an explicit proponent of this behavior.)
I will merge this tomorrow unless someone raises good objections.
Re: Tracker.jl, I was hoping that was the case. Thanks.
Patch coverage: 100.00% and no project coverage change.
Comparison is base (2001650) 97.86% compared to head (fee3857) 97.86%.
The diffrule for abs has the following comment, which I'm not sure how to interpret.
git blame shows it was added in #33, and there the explanation is arguably a bit clearer: DiffRules._abs_deriv is intended as a hook that downstream packages, such as the ones for interval arithmetic, can overload. Based on a JuliaHub search (https://juliahub.com/ui/Search?q=_abs_deriv&type=code), it seems that no public package actually overloads it (anymore).
Some additional historical context: it seems the rule for abs was originally added in https://github.com/JuliaDiff/DiffRules.jl/pull/11, and there it was suggested to use signbit since at that time it was used in abs(::ForwardDiff.Dual) (which was later removed in https://github.com/JuliaDiff/ForwardDiff.jl/pull/311).
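For readers following along, the two conventions discussed in this thread can be written as standalone functions. This is only a sketch: the names abs_deriv_signbit and abs_deriv_sign are hypothetical, and the bodies assume that the older rule branches on signbit while the newer one uses sign. It is not a copy of the code in DiffRules.jl.

```julia
# Hypothetical names, for illustration only; not the definitions in DiffRules.jl.

# Assumed older convention: branch on the sign bit, so the derivative is
# never zero. At +0.0 it is 1, at -0.0 it is -1.
abs_deriv_signbit(x) = signbit(x) ? -one(x) : one(x)

# Assumed newer convention: use sign, which is zero at the kink.
abs_deriv_sign(x) = sign(x)

abs_deriv_signbit(0.0)   # 1.0
abs_deriv_signbit(-0.0)  # -1.0
abs_deriv_sign(0.0)      # 0.0
abs_deriv_sign(-2.5)     # -1.0
```

Away from zero the two functions agree; they differ only at the non-differentiable point.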
@devmotion Thanks for the extra context.
I think we should revert this. It breaks higher-order derivatives for some differentiable functions. For example:
julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
2.0
(TestDiffRules) pkg> add DiffRules@1.14
Resolving package versions...
Updating `~/TestDiffRules/Project.toml`
[b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0
Updating `~/TestDiffRules/Manifest.toml`
[b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0
julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
0.0
The example here is of course trivial, but abs is used in many places for numerical reasons in differentiable functions such as https://github.com/JuliaStats/Distributions.jl/blob/2dee35e13eacb0909c6b2189f229ce93c04d2560/src/univariate/continuous/logistic.jl#L82.
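To see where the 2.0 versus 0.0 comes from, here is a self-contained sketch using a minimal nested dual number. It is an illustration under stated assumptions, not ForwardDiff's implementation: Dual, abs_with_rule, sq, second_derivative, rule_signbit and rule_sign are all hypothetical names, and the two rules are assumed to be the signbit-based and sign-based conventions discussed above. By the chain rule, the second derivative of abs(t)^2 is 2*d(t)^2 + 2*abs(t)*d'(t), where d is the rule used for the derivative of abs. At t = 0 only the first term survives, so the result is 2 when d(0) = 1 and 0 when d(0) = 0.

```julia
# Minimal forward-mode dual number (illustrative only).
struct Dual{T}
    val::T   # primal value
    der::T   # derivative carried along with it
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.der + a.der * b.val)
Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.one(a::Dual) = Dual(one(a.val), zero(a.val))
Base.signbit(a::Dual) = signbit(a.val)
# sign is piecewise constant, so its own derivative is taken to be zero.
Base.sign(a::Dual) = Dual(sign(a.val), zero(a.val))

# Two candidate rules for the derivative of abs (hypothetical names).
rule_signbit(x) = signbit(x) ? -one(x) : one(x)   # 1 at 0.0
rule_sign(x) = sign(x)                            # 0 at 0.0

# abs with a pluggable derivative rule.
abs_with_rule(x::Real, rule) = abs(x)
abs_with_rule(x::Dual, rule) = Dual(abs_with_rule(x.val, rule), rule(x.val) * x.der)

sq(x) = x * x

# Second derivative via a dual number nested inside a dual number.
function second_derivative(f, t::Float64)
    inner = Dual(t, 1.0)
    outer = Dual(inner, Dual(1.0, 0.0))
    return f(outer).der.der
end

second_derivative(t -> sq(abs_with_rule(t, rule_signbit)), 0.0)  # 2.0
second_derivative(t -> sq(abs_with_rule(t, rule_sign)), 0.0)     # 0.0
second_derivative(t -> sq(abs_with_rule(t, rule_sign)), 1.5)     # 2.0
```

Since abs(t)^2 equals t^2, whose second derivative is 2 everywhere, the sign-based rule gives the wrong answer exactly at t = 0 in this sketch, matching the outputs reported above.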
This PR updates the diffrule for abs to return 0 at the non-differentiable point. The current implementation returns 1. Although valid, this can prevent convergence in gradient descent. The behavior implemented in this PR is what the ChainRules.jl docs advise. It also has the added benefit of not requiring the argument type, such as IntervalArithmetic.Interval, to support the ternary operator. This is the use case that led me to make this PR. With this PR:
The diffrule for abs has the following comment, which I'm not sure how to interpret, as it doesn't work with IntervalArithmetic.Interval or Intervals.Interval. Additionally, the current definition assumes that 0 is not in the interval. https://github.com/JuliaDiff/DiffRules.jl/blob/2001650a4a009d2136f89496f2d22fe7fb04dfbd/src/rules.jl#L71
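The interval concern can be illustrated with a toy type. This is a sketch under loud assumptions: ToyInterval is a hypothetical stand-in and does not reproduce the behavior of IntervalArithmetic.Interval or Intervals.Interval, and deriv_via_sign and deriv_via_branch are hypothetical names for the branch-free and branching rules. The point is only that a rule written with a ternary needs a single Bool for the whole interval, which is ill-defined when the interval contains 0, while a rule written with sign can return an interval.

```julia
# Toy closed interval [lo, hi]; a stand-in, not a real interval package type.
struct ToyInterval
    lo::Float64
    hi::Float64
end

# Interval-valued sign: the hull of the signs attained at the endpoints.
Base.sign(x::ToyInterval) = ToyInterval(sign(x.lo), sign(x.hi))

# Branch-free rule: works for any type that defines sign.
deriv_via_sign(x) = sign(x)

# Branching rule: needs signbit(x) to be one Bool, plus one(x).
deriv_via_branch(x) = signbit(x) ? -one(x) : one(x)

x = ToyInterval(-1.0, 2.0)   # straddles zero
d = deriv_via_sign(x)        # ToyInterval(-1.0, 1.0)

# The toy type defines no signbit method, so the branching rule throws.
branch_fails = try
    deriv_via_branch(x)
    false
catch err
    err isa MethodError
end
```

Here d covers both slopes of abs on the interval, and branch_fails is true because the ternary has no single answer to test.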