JuliaDiff / DiffRules.jl

A simple shared suite of common derivative definitions

Update abs diff rule to 0 at non-differentiable point #98

Closed agerlach closed 1 year ago

agerlach commented 1 year ago

This PR updates the diffrule for abs to return 0 at the non-differentiable point x = 0. The current implementation returns 1 there. Although 1 is a valid subgradient, it can prevent convergence in gradient descent. Returning 0 is the behavior the ChainRules.jl docs advise.
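As a rough illustration of the convergence concern (a sketch, assuming plain gradient descent on f(x) = |x| with a fixed step size η > 0; the symbols g_k and η are notation introduced here, not taken from the PR): any value in [-1, 1] is a valid subgradient at 0, but the choice determines whether the iteration stays at the minimizer.

$$
\partial |x| \,\big|_{x=0} = [-1, 1], \qquad x_{k+1} = x_k - \eta\, g_k
$$

$$
x_k = 0,\; g_k = 1 \;\Rightarrow\; x_{k+1} = -\eta \neq 0, \qquad
x_k = 0,\; g_k = 0 \;\Rightarrow\; x_{k+1} = 0
$$

With g_k = 1 the iterate leaves the minimizer, and the next step (with slope -1 at x = -η) returns to 0, so under these assumptions the iteration alternates between 0 and -η instead of settling. With g_k = 0 the minimizer is a fixed point.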

This also has the added benefit of not requiring the input type to be usable in a boolean context (the ternary operator), which types such as IntervalArithmetic.Interval do not support. This is the use case that led me to make this PR.

using IntervalArithmetic, ForwardDiff

ForwardDiff.derivative(abs, -2.0 .. 2.0)
ERROR: TypeError: non-boolean (Interval{Float64}) used in boolean context
Stacktrace:
 [1] _abs_deriv(x::Interval{Float64})
   @ DiffRules ~/.julia/packages/DiffRules/wKSai/src/rules.jl:73
 [2] abs
   @ ~/.julia/packages/ForwardDiff/vXysl/src/dual.jl:240 [inlined]

With this PR:

ForwardDiff.derivative(abs, -2.0 .. 2.0) # [-1, 1]
ForwardDiff.derivative(abs, 0.0 .. 2.0)  # [0, 1]
ForwardDiff.derivative(abs, -3.0 .. 1.0) # [-1, 1]

The diffrule for abs has the following comment, which I'm not sure how to interpret, as the rule doesn't work with IntervalArithmetic.Interval or Intervals.Interval. Additionally, the current definition assumes that 0 is not in the interval.

https://github.com/JuliaDiff/DiffRules.jl/blob/2001650a4a009d2136f89496f2d22fe7fb04dfbd/src/rules.jl#L71

oxinabox commented 1 year ago

Tracker.jl breakage is unrelated.

I believe this is correct. (But of course I do; I am explicitly a proponent of this property.)

I will merge this tomorrow unless someone raises good objections.

agerlach commented 1 year ago

Re: Tracker.jl, I was hoping that was the case. Thanks.

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (2001650) 97.86% compared to head (fee3857) 97.86%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##           master      #98   +/-   ##
=======================================
  Coverage   97.86%   97.86%
=======================================
  Files           3        3
  Lines         187      187
=======================================
  Hits          183      183
  Misses          4        4
```

| [Impacted Files](https://app.codecov.io/gh/JuliaDiff/DiffRules.jl/pull/98?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff) | Coverage Δ | |
|---|---|---|
| [src/rules.jl](https://app.codecov.io/gh/JuliaDiff/DiffRules.jl/pull/98?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=JuliaDiff#diff-c3JjL3J1bGVzLmps) | `100.00% <100.00%> (ø)` | |


devmotion commented 1 year ago

> The diffrule for abs has the following comment, which I'm not sure how to interpret.

git blame shows it was added in #33 and there the explanation is arguably a bit clearer: DiffRules._abs_deriv is intended as a hook that downstream packages such as the ones for interval arithmetic can overload. Based on a JuliaHub search (https://juliahub.com/ui/Search?q=_abs_deriv&type=code) it seems that no public package actually overloads it (anymore).

devmotion commented 1 year ago

Some additional historical context: It seems the rule for abs was originally added in https://github.com/JuliaDiff/DiffRules.jl/pull/11, and there it was suggested to use signbit since at that time it was used in abs(::ForwardDiff.Dual) (which was later removed in https://github.com/JuliaDiff/ForwardDiff.jl/pull/311).

agerlach commented 1 year ago

@devmotion Thanks for the extra context.

andreasnoack commented 1 year ago

I think we should revert this. It breaks higher order derivatives for some differentiable functions. E.g.

julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
 2.0

(TestDiffRules) pkg> add DiffRules@1.14
   Resolving package versions...
    Updating `~/TestDiffRules/Project.toml`
  [b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0
    Updating `~/TestDiffRules/Manifest.toml`
  [b552c78f] ↑ DiffRules v1.13.0 ⇒ v1.14.0

julia> ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])
1×1 Matrix{Float64}:
 0.0

The example here is of course trivial, but abs is used in many places for numerical reasons inside differentiable functions, such as https://github.com/JuliaStats/Distributions.jl/blob/2dee35e13eacb0909c6b2189f229ce93c04d2560/src/univariate/continuous/logistic.jl#L82.
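The two Hessian values above can be reproduced by hand with the chain rule. This is a sketch only: a(t) is notation introduced here for whatever value the abs rule returns as its derivative, and it assumes the derivative of a itself is treated as finite (in practice zero) at t = 0.

$$
\frac{d}{dt}\,|t|^2 = 2\,|t|\,a(t), \qquad
\frac{d^2}{dt^2}\,|t|^2 = 2\,a(t)^2 + 2\,|t|\,a'(t)
$$

$$
t = 0:\quad \frac{d^2}{dt^2}\,|t|^2 = 2\,a(0)^2 =
\begin{cases}
2, & a(0) = 1 \\
0, & a(0) = 0
\end{cases}
$$

Since |t|^2 = t^2 has second derivative 2 everywhere, a rule with a(0) = 1 matches the true Hessian, which is the result shown for v1.13.0. A rule with a(0) = 0 yields 0, which is the result shown for v1.14.0.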