JuliaDiff / ForwardDiff.jl

Forward Mode Automatic Differentiation for Julia

Inequality, ε-balls, and accidental structural zeros #480

Open mcabbott opened 3 years ago

mcabbott commented 3 years ago

Consider this function:

using Test, ForwardDiff, FiniteDifferences

sq(x) = x==1 ? one(x) : x^2

@test FiniteDifferences.central_fdm(5, 1)(sq, 1) ≈ 2.0
@test_broken ForwardDiff.derivative(sq, 1.0) == 2.0

Here ForwardDiff gets the wrong answer, according to your first calculus class: the derivative is defined by taking limits, evaluating sq(x + ε) for some small ε, and these evaluations always see the continuous x^2, not the special point.

One way to think about this is to say that x==1 really means abs(x-1) < ζ with some tiny ζ, which we keep finite until we are sure we aren't confused. The calculus class assumption is that ζ << ε.

The assumption of ForwardDiff is the opposite. Its Dual(x,1) encodes a perturbation x + 1ε with ε smaller than everything else around, and in particular ε << ζ. In other words, sq is viewed as being piecewise continuous, with a small flat area of width 2ζ, which is still large enough for us to see that its slope is zero.
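
To make the ε << ζ reading concrete, here is a minimal sketch using a hypothetical `ToyDual` type (not ForwardDiff's `Dual`) whose `==` compares only the value, as ForwardDiff currently does. Its derivative of `sq` at 1 comes out as 0, while a central difference with a finite step, which never lands on the special point, gives about 2.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual): val + ε*der
struct ToyDual
    val::Float64
    der::Float64
end

# Product rule, and positive integer powers by repeated multiplication
Base.:*(x::ToyDual, y::ToyDual) = ToyDual(x.val * y.val, x.val * y.der + x.der * y.val)
function Base.:^(x::ToyDual, n::Integer)
    n >= 1 || error("this sketch only handles positive integer powers")
    p = x
    for _ in 2:n
        p = p * x
    end
    return p
end
Base.one(::ToyDual) = ToyDual(1.0, 0.0)

# ForwardDiff-style rule: compare the value only, ignoring the perturbation
Base.:(==)(x::ToyDual, y::Real) = x.val == y

sq(x) = x == 1 ? one(x) : x^2

# Derivative by seeding der = 1
toy_derivative(f, x) = f(ToyDual(x, 1.0)).der

# Central difference with a finite step: x ± h never hits the special point
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

println(toy_derivative(sq, 1.0))  # 0.0, from the flat x == 1 branch
println(central_diff(sq, 1.0))    # close to 2.0
```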

Of course nobody really writes contrived examples like sq. But they do write things like this:

function prod1(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
    end
    p
end

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

@test ForwardDiff.gradient(prod1, [1,2,0,4,0,6]) == zeros(6)
@test_broken ForwardDiff.gradient(prod2, [1,2,0,4,0,6]) == zeros(6)

This has almost the same problem as #197, where det(A) tests for istriu(A) || istril(A) before calling a simpler branch. The fact that f(x,y) == g(x,y) when y==0 does not imply that df/dy == dg/dy. So it seems AD ought not to take that branch.

In which case we want something like this:

using ForwardDiff: Dual

Base.:(==)(x::Dual, y::Int) = x.value == y && iszero(x.partials)
Base.:(!=)(x::Dual, y::Int) = x.value != y || !iszero(x.partials)

This fixes the tests above, and (a slightly more careful version) fixes #197 and #407.
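
A self-contained sketch of the proposed rule, again with a hypothetical `ToyDual` rather than ForwardDiff itself: a dual equals a plain number only if its value matches and its perturbation is exactly zero. Running forward mode one coordinate at a time on `prod2`, the early exit is taken only once the product is a genuine zero, so the gradient agrees with `prod1`. Julia's generic `!=` falls back to `!(x == y)`, so only `==` is defined here.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual)
struct ToyDual
    val::Float64
    der::Float64
end
Base.one(::Type{ToyDual}) = ToyDual(1.0, 0.0)
Base.:*(x::ToyDual, y::ToyDual) = ToyDual(x.val * y.val, x.val * y.der + x.der * y.val)

# Proposed rule: equal only if the value matches AND there is no perturbation
Base.:(==)(x::ToyDual, y::Real) = x.val == y && iszero(x.der)

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

# Forward-mode gradient: seed one coordinate at a time
function toy_gradient(f, xs::Vector{<:Real})
    map(eachindex(xs)) do i
        duals = [ToyDual(Float64(x), i == j ? 1.0 : 0.0) for (j, x) in enumerate(xs)]
        f(duals).der
    end
end

println(toy_gradient(prod2, [1, 2, 0, 4, 0, 6]))
```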

However, it means that fun(Dual(x,1)).value need not equal fun(x) when fun is discontinuous. On the other hand, fun(Dual(x,0)).value should still be equal, @assert zero(x) == 0 isn't broken, and there should be no problems where functions use things like zero(eltype(xs)) for type stability.
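
The trade-off can be seen with a hypothetical function `jump` (not from the thread) that is discontinuous at 1. Under the proposed rule, sketched once more with a throwaway `ToyDual` type, a perturbed input takes the continuum branch, so its primal value differs from `jump(1.0)`, while an unperturbed dual and `zero(x) == 0` behave as before.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual)
struct ToyDual
    val::Float64
    der::Float64
end
Base.zero(::ToyDual) = ToyDual(0.0, 0.0)

# Proposed rule: equal only if the value matches and there is no perturbation
Base.:(==)(x::ToyDual, y::Real) = x.val == y && iszero(x.der)

# Hypothetical discontinuous function with a jump at x == 1
jump(x) = x == 1 ? zero(x) : x

println(jump(1.0))                     # 0.0
println(jump(ToyDual(1.0, 1.0)).val)   # 1.0: the perturbed input sees the continuum
println(jump(ToyDual(1.0, 0.0)).val)   # 0.0: an unperturbed dual still matches
println(zero(ToyDual(3.0, 1.0)) == 0)  # true
```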

The idea that the forward evaluation is unchanged is often thought of as an axiom of AD, but for discontinuous functions, I think that's another way of saying ε << ζ. Which is a choice. And one that your calculus teacher would disapprove of. The point of evaluating a function with dual numbers is, presumably, to find derivatives, so finding them correctly ought to have a higher priority.

There are other comparisons to think about, for example:

sq2(x) = x>1 ? x^2 : x<1 ? x^2 : one(x)

clamp2(x, lo=0, hi=1) = x>hi ? oftype(x,hi) : x<lo ? oftype(x,lo) : x
clamp3(x, lo=0, hi=1) = x>=hi ? oftype(x,hi) : x<=lo ? oftype(x,lo) : x

[ForwardDiff.derivative(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] == [1,1,0]
[central_fdm(5, 1)(cl, 1.0) for cl in [x->clamp(x,0,1), clamp2, clamp3]] ≈ [0.5, 0.5, 0.5]

I'm not sure how often simulating x==1 as in sq2(x) happens in the wild. Perhaps from some combination like f(x) = relu(x) + 0.1*relu(-x)?

But clamping parameters to some range is routine. Your calculus teacher would throw an error here, but that's probably not the most helpful response for the computer.

Returning a nonzero derivative here is useful because, if this is some parameter being optimised, it means gradient descent won't get stuck against the wall when the gradient points away from it. So you can argue that the ability to choose which sub-gradient ForwardDiff will use is a feature. The 0.5 gradient à la FiniteDifferences would also be fine for getting unstuck, but it's very difficult to picture how ForwardDiff could evaluate both branches, and easy to picture doing so having awful side effects.
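
A sketch of the clamping examples under the current value-only rule for `>`, `<`, `>=` and `<=`, using a hypothetical `ToyDual` rather than ForwardDiff. `clamp2` and `clamp3` differ only in strict versus non-strict inequalities, which is what selects the sub-gradient 1 or 0 at the wall, whereas a central difference straddles both branches and averages to 0.5.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual)
struct ToyDual
    val::Float64
    der::Float64
end
Base.convert(::Type{ToyDual}, y::Real) = ToyDual(Float64(y), 0.0)

# Current-style comparisons: look at the value only
Base.:>(x::ToyDual, y::Real) = x.val > y
Base.:<(x::ToyDual, y::Real) = x.val < y
Base.:>=(x::ToyDual, y::Real) = x.val >= y
Base.:<=(x::ToyDual, y::Real) = x.val <= y

clamp2(x, lo=0, hi=1) = x>hi ? oftype(x,hi) : x<lo ? oftype(x,lo) : x
clamp3(x, lo=0, hi=1) = x>=hi ? oftype(x,hi) : x<=lo ? oftype(x,lo) : x

toy_derivative(f, x) = f(ToyDual(x, 1.0)).der
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

for cl in (clamp2, clamp3)
    println(cl, ": dual ", toy_derivative(cl, 1.0), ", central difference ", central_diff(cl, 1.0))
end
```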

Here is one way to relate the present rule for >(::Dual, ::Real) and >=(::Dual, ::Real) to the finite-everything ζ << ε story. We can say that while the ε-ball overlaps with both sides, the vote from the longer side (longer by about 2ζ) always wins by a hair:

----------(==========1==========)----------  abs(x-1) < ε
---------------------1-(===================  x > 1+ζ
          +++++++++++++..........            gradient votes, clamp2(1.0)

With the above ==(::Dual, ::Real) rule, the tests of this package all seem to pass, except for the ones explicitly testing such rules. It would be interesting to know if this breaks any other uses in the wild. It would also be interesting to think up other pathological examples; maybe I've missed something important.

Also:

[ForwardDiff.derivative(f, 0.0) for f in [sq, sq2, sq3, sq4, sq5]] == [0,0,0,0,0]
[ForwardDiff.hessian(x -> f(x[1]), [0.0])[1] for f in [sq, sq2, sq3, sq4, sq5]] == [2,2,-2,2,2]


A rule was suggested there in which `x > y` behaves differently for `x.value == y.value`, breaking such ties by comparing `x.partials > y.partials`. In the `clamp2` example, whether you get stuck against the wall presumably shouldn't depend on whether you minimise `loss(x)` or maximise `-loss(x)`, so we probably don't want to compare `x.partials .> 0` when only `x` is a dual number. But the rule when both `x` and `y` are dual might be worth some more thought. 
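
One way to read the tie-breaking rule is as a one-sided derivative. In this sketch (hypothetical `ToyDual`, dual-versus-real case only), `x > y` at equal values is decided by the sign of the perturbation, as if evaluating at x + ε*der. Seeding +1 or -1 then probes the right or left side of the wall in `clamp2`, so the answer depends on the sign of the seed, which is the kind of sign dependence the paragraph above is wary of.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual)
struct ToyDual
    val::Float64
    der::Float64
end
Base.convert(::Type{ToyDual}, y::Real) = ToyDual(Float64(y), 0.0)

# Hypothetical tie-breaking rule: at equal values, the sign of der decides
Base.:>(x::ToyDual, y::Real) = x.val > y || (x.val == y && x.der > 0)
Base.:<(x::ToyDual, y::Real) = x.val < y || (x.val == y && x.der < 0)

clamp2(x, lo=0, hi=1) = x>hi ? oftype(x,hi) : x<lo ? oftype(x,lo) : x

# Seed +1 probes the right-hand side of x = 1, seed -1 the left-hand side
println(clamp2(ToyDual(1.0, 1.0)).der)   # 0.0: pinned at hi
println(clamp2(ToyDual(1.0, -1.0)).der)  # -1.0: moves with x, i.e. slope 1
```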

mcabbott commented 3 years ago

There's another example here: https://github.com/TuringLang/DistributionsAD.jl/pull/23#issuecomment-575671303. Behaviour on tagged ForwardDiff:

using ForwardDiff, LinearAlgebra
A = Matrix(I, 2,2)

ForwardDiff.gradient(x -> sum(x \ [1,3]), A)           # [-1  0; 0 -3]
ForwardDiff.gradient(x -> sum(cholesky(x) \ [1,3]), A) # [-1 -4; 0 -3] 

Aplus2 = [1 0; 0.001 1]  # perturb A[2]
sum(Aplus2 \ [1,3]) - sum(A \ [1,3])                      # -0.001
sum(cholesky(Aplus2) \ [1,3]) - sum(cholesky(A) \ [1,3])  # PosDefException: matrix is not Hermitian

Aplus23 = [1 0.001; 0.001 1]  # perturb A[2] and A[3] together
sum(Aplus23 \ [1,3]) - sum(A \ [1,3])                     # -0.004
sum(cholesky(Aplus23) \ [1,3]) - sum(cholesky(A) \ [1,3]) # -0.004

Perhaps one could claim the cholesky result is a particular sub-gradient. But will the algorithm always behave this way? I don't know. The one for simple \ is silently wrong, or at least it has promoted A to be structurally diagonal. The Jacobians are:

ForwardDiff.jacobian(x -> x \ [1,3], A)
 # -1.0  -0.0  -0.0  -0.0
 # -0.0  -0.0  -0.0  -3.0

ForwardDiff.jacobian(x -> cholesky(x) \ [1,3], A)
 # -1.0  0.0  -3.0   0.0
 #  0.0  0.0  -1.0  -3.0

With the proposed change to ==, the results instead match finite differences. Here \ may choose a different algorithm, but this shouldn't cause a discontinuity of the forward pass.

ForwardDiff.gradient(x -> sum(x \ [1,3]), A)           # [-1 -3; -1 -3]
ForwardDiff.gradient(x -> sum(cholesky(x) \ [1,3]), A) # ERROR: PosDefException

ForwardDiff.jacobian(x -> x \ [1,3], A)
 # -1.0   0.0  -3.0   0.0
 #  0.0  -1.0   0.0  -3.0

Aplus2 = [1 0; 0.001 1]
Aplus2 \ [1, 3] - A \ [1, 3]  # [0, -0.001]

Aplus3 = [1 0.001; 0 1]
Aplus3 \ [1, 3] - A \ [1, 3]  # [-0.003, 0]
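
The claim that these results match finite differences can be checked without ForwardDiff. This sketch perturbs one entry of `A` at a time with a central difference (the step `h` is an arbitrary choice) using only LinearAlgebra. At `A = I` the analytic gradient of `sum(A \ b)` has entry `-b[j]` in column j, i.e. `[-1 -3; -1 -3]`, the same as the ForwardDiff result quoted above under the proposed `==`.

```julia
using LinearAlgebra

# Central finite differences, perturbing one matrix entry at a time
function fd_gradient(f, A::Matrix{Float64}; h = 1e-6)
    G = similar(A)
    for k in eachindex(A)
        Ap = copy(A); Ap[k] += h
        Am = copy(A); Am[k] -= h
        G[k] = (f(Ap) - f(Am)) / (2h)
    end
    return G
end

A = Matrix(1.0I, 2, 2)
b = [1, 3]

# d/dA[i,j] of sum(A \ b) is -(A' \ ones)[i] * (A \ b)[j], i.e. -b[j] at A = I
G = fd_gradient(x -> sum(x \ b), A)
println(G)
```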

The example just above that, in https://github.com/TuringLang/DistributionsAD.jl/pull/23#issuecomment-574965028, also gives an error on this branch, again from cholesky, as does a finite perturbation.

mateuszbaran commented 3 years ago

I didn't fully think it through, but wouldn't it break functions like f(x) = x == 0.0 ? 1.0 : sin(x)/x? What if I replaced == 0.0 with iszero or ≈ 0.0?

mateuszbaran commented 3 years ago

Right now, if I wanted correct derivatives of such functions up to a certain order, I'd just put a Taylor approximation of the function in the "measure zero branch".

mcabbott commented 3 years ago

This f would indeed change. However, it has problems on a region bigger than one point, and so it would be better anyway for the smoothed branch to be used within some finite interval. For example:

julia> fpi(x) = x==0 ? 1.0 : sin(pi*x) / (pi*x);

julia> fpi(1e-40)
1.0

julia> ForwardDiff.derivative(fpi, 1e-40)
1.2089258196146292e24

This is a variant of #466, for which the solution is

julia> cosc(1e-40) 
-3.2898681336964526e-40

Even if you don't have an exact closed-form derivative, I think that you would usually want to replace f(x) with a constant / polynomial within some small interval, not at a single point. Perhaps always; are there exceptions?
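
A sketch of the "interval, not point" suggestion for `fpi`, with a hypothetical half-width `δ = 1e-3` picked only for illustration. Inside the interval the Taylor polynomial 1 - (πx)^2/6 replaces sin(πx)/(πx); its derivative -(π^2/3)x agrees with Base's `cosc`, the derivative of `sinc(x) = sin(πx)/(πx)`. The derivatives are written out by hand here, not obtained from ForwardDiff.

```julia
# Hypothetical interval half-width, chosen for illustration
const δ = 1e-3

# Taylor polynomial of sin(πx)/(πx) on |x| < δ, the direct formula elsewhere
fpi_safe(x) = abs(x) < δ ? 1 - (pi * x)^2 / 6 : sin(pi * x) / (pi * x)

# Hand-written derivative of each branch (Base.cosc is the derivative of sinc)
dfpi_safe(x) = abs(x) < δ ? -(pi^2 / 3) * x : cosc(x)

for x in (1e-40, 1e-5, 2e-3)
    println(x, ": ", dfpi_safe(x), " vs cosc ", cosc(x))
end
```

The branches nearly agree at the edge of the interval, so a dual number stepping across it sees no large jump in either value or slope.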

Since the fallback is iszero(x) = x == zero(x), I'd be hesitant to mess with that, but isapprox could be given a different behaviour for dual numbers. By default, it doesn't widen the interval at zero:

julia> ForwardDiff.Dual(1e-40,0) ≈ 0
false

julia> nextfloat(0.0) ≈ 0.0
false
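
For reference, this is how Base's `isapprox` treats zero with its default tolerances (`atol = 0`, so only a relative tolerance applies), and how an explicit `atol` widens the interval. Which tolerance a dual-number `isapprox` should use is left open here; `1e-12` is only an example value.

```julia
# Default tolerances: atol = 0, so only exact zero is "approximately" zero
println(isapprox(1e-40, 0.0))               # false
println(isapprox(nextfloat(0.0), 0.0))      # false

# An explicit absolute tolerance (value chosen for illustration) widens it
println(isapprox(1e-40, 0.0; atol = 1e-12)) # true
```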

mcabbott commented 3 years ago

One more data point is that https://github.com/JuliaPhysics/Measurements.jl gets this right according to my argument above. While its error bars aren't exactly dual numbers, they are a related species. Here's a version of https://github.com/JuliaDiff/ForwardDiff.jl/issues/536 in which mul! does not take the shortcut:

julia> using Measurements

julia> λ = measurement(0, 0.1)
0.0 ± 0.1

julia> iszero(λ)
false

julia> A = measurement.([1 2; 3 4], 0.1);

julia> B = measurement.([5 6; 7 8], 0.1);

julia> A * (B * λ)
2×2 Matrix{Measurement{Float64}}:
 0.0±1.9  0.0±2.2
 0.0±4.3  0.0±5.0

julia> mul!(similar(A), A, B, λ, 0)  # like issue 536
2×2 Matrix{Measurement{Float64}}:
 0.0±1.9  0.0±2.2
 0.0±4.3  0.0±5.0

ChrisRackauckas commented 2 years ago

> The idea that the forward evaluation is unchanged is often thought of as an axiom of AD

And it's definitely a false one. It's an axiom for computer scientists, but not for numerical analysts 😅. A nice example is in the space of ODEs. Automatic differentiation is equivalent to solving the expanded ODE known as the forward sensitivity equations, essentially:

u' = f(u,p,t)
d/dt (du/dp) = df/du du/dp + df/dp
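
The expanded system can be made concrete with a toy problem. This sketch assumes the scalar ODE u' = p*u (chosen for illustration, not taken from the linked issue), so the sensitivity s = du/dp obeys s' = p*s + u with s(0) = 0, and integrates both together with a hand-written fixed-step RK4. The closed forms u(T) = u0*exp(p*T) and du/dp = u0*T*exp(p*T) serve as a check.

```julia
# Augmented right-hand side for y = [u, s] with s = du/dp:
#   u' = f(u,p,t) = p*u
#   s' = df/du * s + df/dp = p*s + u
augmented_rhs(y, p) = [p * y[1], p * y[2] + y[1]]

# Classic fixed-step fourth-order Runge-Kutta
function rk4(f, y0, p, T, n)
    h = T / n
    y = copy(y0)
    for _ in 1:n
        k1 = f(y, p)
        k2 = f(y .+ (h / 2) .* k1, p)
        k3 = f(y .+ (h / 2) .* k2, p)
        k4 = f(y .+ h .* k3, p)
        y = y .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return y
end

p, u0, T = 0.5, 2.0, 1.0
u, s = rk4(augmented_rhs, [u0, 0.0], p, T, 1000)

println((u, u0 * exp(p * T)))      # primal vs closed form
println((s, u0 * T * exp(p * T)))  # sensitivity vs closed form
```

Because the step size is fixed, this sketch sidesteps the adaptivity question discussed next; an adaptive solver has to decide whether s enters its error norm.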

Straight automatic differentiation is equivalent to solving the expanded ODE with the adaptive error controls only applying to the first part of the equation u' = f(u,p,t). Are there ODEs for which the second part is unstable when adaptivity is only applied to the first part? Yes. https://github.com/SciML/DiffEqSensitivity.jl/issues/273 is a real-world example where this came up. The solution was https://github.com/SciML/DiffEqSensitivity.jl/issues/273, i.e. the norm used in the ODE solver has to account for the pseudo-ODEs if you want it to be stable, and so the default norm that is used adds the partials to the primal part.

https://github.com/SciML/DiffEqBase.jl/blob/v6.83.1/src/forwarddiff.jl#L31-L34

This means that solving with ForwardDiff gives different stepping behavior, but if you don't do that, then there will be cases where you have "infinite" derivative because of numerical instability in the derivative calculation even when the primal is stable. So definitely, this axiom does not hold for the realities of numerical computing.

So, back to the core of the thread: I definitely agree with you. In fact, DiffEq specializes its interpolation computation to work around this kind of issue. Normally it would just pull sol.u[i] if the requested t matches sol.t[i] exactly, but it still needs to use the interpolation if t is a dual number, since otherwise the derivative is zero. With this ε-ball definition of ==, the workaround forcing Dual numbers not to take the sol.t[i] == t branch would no longer be needed.
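
A toy version of the interpolation shortcut, assuming a hypothetical piecewise-linear `interp` (not DiffEq's actual interpolation code) that returns the stored value when `t` equals a knot, compared by value only. Sampling u(t) = 2t on `[0, 1, 2]` makes the true slope 2 on both sides of the knot at t = 1, yet the shortcut returns a constant and loses it.

```julia
# Hypothetical minimal dual number (not ForwardDiff's Dual)
struct ToyDual
    val::Float64
    der::Float64
end
Base.:-(x::ToyDual, y::Real) = ToyDual(x.val - y, x.der)
Base.:*(x::ToyDual, a::Real) = ToyDual(x.val * a, x.der * a)
Base.:+(a::Real, x::ToyDual) = ToyDual(a + x.val, x.der)
Base.convert(::Type{ToyDual}, y::Real) = ToyDual(Float64(y), 0.0)
primal(t::Real) = t
primal(t::ToyDual) = t.val

# Piecewise-linear interpolant with an optional "exact knot" shortcut
function interp(ts, us, t; knot_shortcut = true)
    if knot_shortcut
        # Value-only comparison, like == on ForwardDiff's Dual today
        i = findfirst(==(primal(t)), ts)
        # Returning the stored value is a constant: the derivative is lost
        i === nothing || return convert(typeof(t), us[i])
    end
    i = clamp(searchsortedlast(ts, primal(t)), 1, length(ts) - 1)
    slope = (us[i+1] - us[i]) / (ts[i+1] - ts[i])
    return us[i] + (t - ts[i]) * slope
end

ts = [0.0, 1.0, 2.0]
us = 2 .* ts  # samples of u(t) = 2t, so du/dt = 2 everywhere

println(interp(ts, us, ToyDual(1.0, 1.0)).der)                         # 0.0
println(interp(ts, us, ToyDual(1.0, 1.0); knot_shortcut = false).der)  # 2.0
```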

mcabbott commented 1 year ago

Perhaps it's worth noting that many other AD systems have the same problem. On the example above:

julia> Zygote.gradient(prod1, [1,2,0,4,0,6])
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> Zygote.gradient(prod2, [1,2,0,4,0,6])  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0],)

julia> Tracker.gradient(prod1, [1,2,0,4,0,6])  # (after removing the ::Vector restriction)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0] (tracked),)

julia> Tracker.gradient(prod2, [1,2,0,4,0,6])  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0] (tracked),)

julia> dx = zeros(6); Enzyme.autodiff(prod1, Duplicated([1,2,0,4,0,6.], dx)); (dx,)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> dx = zeros(6); Enzyme.autodiff(prod2, Duplicated([1,2,0,4,0,6.], dx)); (dx,)  # wrong
([0.0, 0.0, 2.0, 0.0, 0.0, 0.0],)

I think Zygote cannot fix this, as it does not know which variables are active when it transforms code. Tracker could surely change == for TrackedReal much like #481 here. Enzyme does know about activity; the relevant issue there is https://github.com/EnzymeAD/Enzyme.jl/issues/114

devmotion commented 1 year ago

I think @wsmoses's comment is quite interesting. I've never viewed it that way, but I guess one could say that all these AD systems are correct, as they correctly return the gradient of the function that is implemented/defined by prod2; they seem incorrect only if one considers prod2 just as an optimization of prod1 and actually wants the gradients of the function implemented by prod1. In any case, I guess these differences are surprising for most users, who, I assume, don't expect to see any difference here and have not thought about the implications of different implementations for AD.

bvdmitri commented 1 year ago

It would be best if it were only surprising, but for users this behaviour is frustrating, to say the least. From this discussion and a discussion on Slack, it turned out that this is a known issue and a property of many other AD backends, but for some reason it is not communicated well to end users, who may rely on this in critical systems or extensive simulations. That is important, because it happens not only in toy examples but in real code as well. A good example from our case is dot(x, A * x) versus dot(x, A, x), which produce different Hessians if x is a zero vector (and the 3-argument function produces not just a different but a completely incorrect result). This happens all the time when you evaluate the Gaussian logpdf at its mean. And the dot function is not even written by us; it comes from Julia's LinearAlgebra. Distributions.jl is unaffected by pure luck, because PDMats uses the 2-argument dot function.

These kinds of limitations should be communicated better (e.g. in documentation) to end users, who indeed have not thought about the implications of different implementations for AD. AD systems position themselves as a fast and more accurate alternative to finite differences, but do not document clear (known!) pitfalls. That is bad, and it is not specific to ForwardDiff. ForwardDiff is an amazing library, but why should a user start thinking about the potential implications of different AD backends if there is no indication that something may go wrong in the first place? That's an open question, of course.

It's great that ForwardDiff has a solution, and it looks like the fix has been merged into master. I'm looking forward to the fix being released.