SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Inconsistent gradient for toy adaptive ODE simulation #1094

Closed frankschae closed 3 weeks ago

frankschae commented 3 weeks ago

Question❓

Is it expected that the gradients from ReverseDiff and FiniteDiff disagree, even though the solver steps to almost exactly the same times? The difference goes away when the tolerances are decreased. We saw the same thing happen in diffrax: https://github.com/patrick-kidger/diffrax/issues/499. (For ForwardDiff, one can verify that the solver takes different steps, and consequently we obtain a different gradient as well.) CC: @lockwo

using OrdinaryDiffEq
using FiniteDiff, ForwardDiff, Statistics, Zygote, ReverseDiff, SciMLSensitivity

function f(X, args, t)
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + sqrt(2) / 8) * y2
    return [dy1, dy2]
end

u0 = [1.0, 1.0]
args = ones(1)
odeprob = ODEProblem(f, u0, (0.0, 3.0), args)

function loss(u0)
    _prob = remake(odeprob, u0=u0)
    _sol = (solve(_prob, Heun(),
        dt=0.1,
        abstol=0.1,
        reltol=0.1,
        save_everystep=true,
        save_start = false,
        #adaptive=false
        controller=IController(), #CustomController(),
        sensealg=ReverseDiffAdjoint()
    ))
    @show (_sol.t)
    _sol = _sol[end]

    return sum(abs2, _sol)
end

function finite_diff(u0)
    eps = 1e-5
    v1 = loss([u0[1] + eps, u0[2]])
    v2 = loss([u0[1], u0[2]])
    v3 = loss([u0[1], u0[2] + eps])
    #v4 = loss([u0[1], u0[2] - eps / 2])

    @show v2

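    # Forward differences against the base value v2.
    # NOTE: the finite_diff(u0) value logged below was generated with `v3 - v1` in the second entry.
    [v1 - v2, v3 - v2] / eps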
end

begin
    @show finite_diff(u0)
    println("FiniteDiff")
    grad1 = FiniteDiff.finite_difference_gradient(loss, u0)
    println("Forward")
    grad2 = ForwardDiff.gradient(loss, u0)
    println("Reverse")
    grad3 = Zygote.gradient(loss, u0)[1]
    @show grad1 grad2 grad3
end
_sol.t = [0.1, 0.613929509825019, 1.1997800701371129, 1.7608249433819432, 2.289930281088923, 2.7977639535014687, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
_sol.t = [0.1, 0.6139284165641612, 1.1997779517656673, 1.7608221635535624, 2.2899271418305647, 2.7977606306650937, 3.0]
v2 = 2286.5798548666576
finite_diff(u0) = [4.671882288675988, 4825.2249315282825]
FiniteDiff
_sol.t = [0.1, 0.613929509825019, 1.199780114366179, 1.7608249943840808, 2.28993033847855, 2.797764017021789, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997803015279263, 1.7608252770250412, 2.2899306781168063, 2.7977643468051436, 3.0]
_sol.t = [0.1, 0.6139288538679462, 1.199778799113144, 1.7608232612500783, 2.2899284095313726, 2.7977619646944447, 3.0]
_sol.t = [0.1, 0.6139302969757168, 1.1997816219077029, 1.7608269783142547, 2.28993257339865, 2.7977663637197554, 3.0]
Forward
_sol.t = [0.1, 0.728028559619219, 1.4979699401846973, 2.275540178765593, 3.0]
Reverse
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
grad1 = [3.4720645693875793, 4831.212995065322]
grad2 = [-11.629850313906191, 3561.5179283293633]
grad3 = [-15.024234205512382, 4588.183943938828]
ChrisRackauckas commented 3 weeks ago

That's not unexpected? Native autodiff of adaptive solvers always has this issue: you can give it ODEs where its error is epsilon for any epsilon you choose. The paper shows a simple way to do the construction, and from there you can see how to generalize it to the full result. Since ReverseDiff does not have the correction, the reverse-mode process generates an approximating ODE: not a direct calculation of the derivative, but an approximation to a process that converges to the correct ODE as dt -> 0. However, the adaptivity does not directly control dt for that method, only for the forward method. So you know pretty much nothing in that case other than that it converges. The derivative is still an ODE, so you have error, but that error is not necessarily even bounded (there are counterexamples you can construct where it is unbounded). The only reason that's not true for ForwardDiff is that it corrects for the stepping going forward in a way that's not naive autodiff, which of course has its own oddities, but it's a requirement in order to ensure a convergent method.

The point really is that (a) your gradient is also given by an ODE, and you're simply approximating it, so that approximation needs to be sufficiently accurate or you'll have solver-induced error in it, and (b) not all methods (native autodiff of solvers especially) are convergent as the tolerance goes to zero in all cases. In other words, if your tolerances and time steps are large, then your approximation to the gradient will also have a large error, and that's inherent in the fact that you're only ever approximating the gradient process. You can't even guarantee that a dt chosen for the forward process also gives a stable solve for the reverse process (which means reverse-mode AD can generate an unstable reverse method from a stable forward method! There are some nice examples that make this unconditionally true, BTW). You can only expect convergence, and that you get a stable method if dt is sufficiently small and the solvers are chosen correctly. You cannot really expect much at higher tolerances, just like you can't for any ODE solver (since tolerances are local properties and thus do not guarantee stability with the chosen dts).
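For reference, the statement that the gradient is itself given by an ODE can be written out as follows. This is the standard textbook form, not something taken from the thread; the symbols $S$, $\lambda$, and $L$ are notation assumed here for the forward sensitivity, the adjoint state, and a loss that depends only on the final state $u(T)$.

```latex
% Primal problem
\frac{du}{dt} = f(u, p, t), \qquad u(0) = u_0

% Forward sensitivity S(t) = \partial u(t) / \partial u_0
\frac{dS}{dt} = \frac{\partial f}{\partial u}\, S, \qquad S(0) = I,
\qquad \frac{dL}{du_0} = S(T)^{\top} \frac{\partial L}{\partial u(T)}

% Continuous adjoint, integrated backwards from t = T to t = 0
\frac{d\lambda}{dt} = -\left(\frac{\partial f}{\partial u}\right)^{\top} \lambda,
\qquad \lambda(T) = \frac{\partial L}{\partial u(T)},
\qquad \frac{dL}{du_0} = \lambda(0)
```

Either way, the gradient is the solution of a differential equation, so any numerical gradient carries its own discretization error and has its own stability requirements.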

lockwo commented 3 weeks ago

For the error between the true derivative d/dx0 ODE(x0) and the reverse AD derivative with respect to x0, I see that our derivative should be off. But comparing finite diff and reverse AD, we see that finite diff steps to almost exactly the same points as reverse diff (the difference is well below epsilon). So I would expect finite diff and reverse AD to be "wrong" relative to the true gradient, but wrong in a similar manner?

Additionally, independent of the accuracy of AD relative to finite diff, how much does the AD error relative to the true gradient matter in practice? SGD usually has approximations of the true gradient normally (and you can add noise/do all sort of weird stuff to gradients and still converge in ML), so would this meaningfully affect training? If it does, and training w/ reverse diff on systems such as this is a dangerous idea, we would need to turn down the tolerances, but as you remarked "not all methods (native autodiff of solvers especially) are convergent as tolerance goes to zero in all cases" this means we can't just decrease tolerance and fix the problem (universally). Are these methods where turning down the tolerance isn't a valid solution to get AD better very chaotic systems or fluid dynamics systems where FP accuracy matters, or would some toy ODEs also break?

ChrisRackauckas commented 3 weeks ago

> Are these methods where turning down the tolerance isn't a valid solution to get AD better very chaotic systems or fluid dynamics systems where FP accuracy matters, or would some toy ODEs also break?

Some of the counterexamples are non-stiff linear systems. You don't need chaos or anything of the sort to have such issues. The paper example, for instance, is fully non-convergent, non-stiff, and linear. There are also degrees of how bad an example can be, so while that example is rather pathological (though again, linear and seemingly reasonable at first glance), you can also construct examples on a sliding scale from non-pathological to pathological, where the adaptivity can do something but just leads to a local truncation error that is a constant C times larger than the local truncation error of the primal. The point really is that all values of C are possible with even linear examples, even C=inf.

> SGD usually has approximations of the true gradient normally (and you can add noise/do all sort of weird stuff to gradients and still converge in ML), so would this meaningfully affect training? If it does, and training w/ reverse diff on systems such as this is a dangerous idea, we would need to turn down the tolerances, but as you remarked "not all methods (native autodiff of solvers especially) are convergent as tolerance goes to zero in all cases" this means we can't just decrease tolerance and fix the problem (universally).

Without proper numerical analysis and examples of the dynamics of inaccurate ODE gradients, I'll refrain from making any statement as to whether the discrete adjoint gradient is okay, or more okay than the continuous adjoint, in the context of high error. There are some good reasons to suggest it's not, though. For one, you do need to ensure your gradient is stable. Maybe a good example to add to the paper is a quick example of where it's not. If you do implicit Euler going forwards, then you have forward Euler as the method approximating the gradient process in reverse. It's not a standard forward Euler, because you're using the u values from the forward pass for the vjp defining the ODEs, so in a sense it's "snapping back" to the original process in an odd way, but you do have a method whose local propagation is the "adjoint method of the ODE stepper", which is explicit Euler in that case. Can you construct examples where the implicit Euler pass forwards is stable for a choice of dt but the reverse pass is unstable with those dts? Absolutely. I used to do that one in class as a nice example for thinking about the numerical analysis of derivatives, as it really highlights that the way the derivative is approximated matters. The point, though, is: if the reverse propagation is numerically unstable, is it a useful gradient at all? I don't think it's possible to make a statement about its utility in that case, since without stability any sense of approximation accuracy becomes unbounded.

But okay, if you have chosen a method for which the backwards pass is stable at the given dts, but you still have sufficiently high error, is this a regime where you would expect discrete adjoints to have better error properties than continuous adjoints? Again, you cannot guarantee you are in this case without error control, but it's plausible that this case is common enough to consider. In that case, maybe? But there are still arguments against it. In the common cases of machine learning, you have quasi-static codes in the sense of https://www.stochasticlifestyle.com/useful-algorithms-that-are-not-optimized-by-jax-pytorch-or-tensorflow/, i.e. codes which may have control flow but whose flow is the same regardless of the input values. So while you may have control flow, there exists a fully unrolled code with no control flow that is the same code for all inputs. Control flow is just an abstraction in this kind of case. That's your standard machine learning case, and non-adaptive ODE solvers are in that case. There, you can say that the discrete adjoint is approximating not necessarily the true process but the discrete code, and it will do so in a manner where the errors match the errors of the discretization, and that can be okay.

But adaptive ODE solvers are not in that case. If you change the inputs or tolerances, you get, in a sense, a different code. It's not, say, fundamentally different, but you'll loop a different number of times, and this is a discontinuous change to the code. So as you change the tolerances, you're not just changing the dts, you're also changing the graph that you're differentiating. How do the properties of the discrete derivative of the first compute graph compare with those of the discrete derivative of the compute graph at the lower tolerances? In some sense, the compute graph of the lower tolerance case is almost like a continuous relaxation, so in theory I think something can be pieced together to do numerical analysis on this. But the naive answer is that they are different compute graphs, so every time the tolerance changes in a way that changes a step (turns a step reject into an accept, changes the number of steps), you have a discontinuity, and it's not clear how to think about how those discontinuities affect the error of the process.
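To make the "different code at different tolerances" point concrete, here is a minimal sketch of an adaptive Heun-Euler integrator in plain Julia. It is not OrdinaryDiffEq's implementation: the function `adaptive_heun`, its controller constants, and the test problem `decay` are all assumptions chosen for illustration. It only shows that the number of loop iterations, and hence the compute graph an AD tool would trace, depends on the tolerance.

```julia
# Minimal adaptive Heun-Euler integrator for a scalar ODE u' = f(u, t).
# Illustrative only: the controller constants below are assumptions,
# not the ones used by any particular library.
function adaptive_heun(f, u0, tspan; abstol, reltol, dt=0.1)
    t, T = tspan
    u, h = u0, dt
    ts = Float64[]                       # accepted time points
    while T - t > 1e-12
        h = min(h, T - t)                # do not step past the end time
        k1 = f(u, t)
        k2 = f(u + h * k1, t + h)
        u_new = u + h / 2 * (k1 + k2)    # Heun step (order 2)
        err = abs(h / 2 * (k2 - k1))     # difference to the embedded Euler step
        scaled = err / (abstol + reltol * max(abs(u), abs(u_new)))
        if scaled <= 1                   # accept the step
            t += h
            u = u_new
            push!(ts, t)
        end
        # Integral-type step size update with a safety factor and growth limits
        h *= clamp(0.9 / sqrt(max(scaled, 1e-10)), 0.2, 5.0)
    end
    return u, ts
end

decay(u, t) = -u

for tol in (1e-1, 1e-3, 1e-6)
    u, ts = adaptive_heun(decay, 1.0, (0.0, 3.0); abstol=tol, reltol=tol)
    println("tol = ", tol, ": ", length(ts), " accepted steps, u(3) ≈ ", u)
end
```

Each tolerance produces a different number of accepted steps, so differentiating through the loop differentiates a different unrolled program each time.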

So given that we have counterexamples to some cases that we can build, my only firm answer here is that the error propagation in automatic differentiation is much more subtle than many wish to sell it. The common way of thinking about it is that, at each step of the process, we write down the jvp/vjp and the error is small, so therefore the total error is small. But that's the jump: "so therefore the total error is small". Numerical analysis is basically an entire discipline that puts a warning sign on that statement. LU/QR/SVD factorization is just *, +, and /, but if you do this enough then you have error growth based on the (square root of the) condition number, so little 1e-16 errors build up unless you change your process to be more stable (pivoting). In ODEs, you can have a small truncation error at every step but still get junk in the solution because of how the errors compound; this is stability. So, the errors of discrete adjoints are small at every step and match the properties of our discretized algorithm, and therefore they should approximate the continuous system well? Nope, not really. There are linear counterexamples, and I believe there's a lot more that needs to be studied in this subject.
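The ODE remark (small truncation error per step, junk in the solution) can be illustrated with a classic stiff test problem. This sketch is not from the thread: the problem u' = λ(u - cos t) - sin t with exact solution cos t, the helper name `euler_error`, and the chosen step sizes are all assumptions. The solution is smooth, so the local truncation error of explicit Euler is roughly h²/2 per step in both runs, but each step also multiplies the accumulated error by (1 + hλ).

```julia
# Explicit Euler on u' = λ (u - cos(t)) - sin(t), u(0) = 1, whose exact solution is cos(t).
# Returns the absolute error at the final time.
function euler_error(λ, h, nsteps)
    f(u, t) = λ * (u - cos(t)) - sin(t)
    u, t = 1.0, 0.0
    for _ in 1:nsteps
        u += h * f(u, t)
        t += h
    end
    return abs(u - cos(t))
end

λ = -50.0
err_stable   = euler_error(λ, 0.01, 100)  # 1 + hλ = 0.5:  errors are damped
err_unstable = euler_error(λ, 0.05, 20)   # 1 + hλ = -1.5: errors grow each step

println("h = 0.01: error at t = 1 is ", err_stable)
println("h = 0.05: error at t = 1 is ", err_unstable)
```

Both runs integrate to t = 1 with a small per-step truncation error, but only the first stays close to cos(1); in the second the per-step errors are amplified until they dominate the result.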

(An unrelated note is that I'm currently working through a project where it seems the gradient requires 128-bit precision as it loses 64-bits of precision due to floating point inaccuracies in the generated derivative process even though the primal has about ~6 ulp accuracy or about 10 digits of accuracy... yikes! So, automatic differentiation even when correct can have more interesting numerical issues than most people think. That of course is about a year away from publication, but it's a nice example to note that autodiff done correctly still has many ways to hit numerical issues which can be proven from its analytical form).

So anyways, in conclusion, right now from what we know I think the only firm advice we can give people is:

  1. Make sure the tolerance is sufficiently small, as it can affect your gradient accuracy (though note that tolerance only controls local error)
  2. Though it does not affect the gradient accuracy of all methods globally on all examples, so beware (and on examples where tolerance does affect the accuracy directly, like continuous adjoint methods or modified forward mode, it only sets a local, not global, tolerance, so buyer beware there)
  3. The best way to confirm the size of the relative error is to double-check with a few different ways of calculating the gradient.

And of course this has always been true with differentiation of ODEs so you can make these examples with CVODES, FATODE, etc. too, so don't shoot the messenger 😅
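As a small sketch of the third point, the gradients reported in the first post can be compared by their relative discrepancy. The helper `reldiff` is a hypothetical name introduced here; the numbers are copied from the output above (abstol = reltol = 0.1), and none of the three is treated as the true gradient.

```julia
using LinearAlgebra  # for norm

# Gradients as reported in the first post of this thread
grad_finitediff  = [3.4720645693875793, 4831.212995065322]
grad_forwarddiff = [-11.629850313906191, 3561.5179283293633]
grad_reversediff = [-15.024234205512382, 4588.183943938828]

# Relative discrepancy between two gradient estimates
reldiff(g, gref) = norm(g - gref) / norm(gref)

println("ForwardDiff vs FiniteDiff:  ", reldiff(grad_forwarddiff, grad_finitediff))
println("ReverseDiff vs FiniteDiff:  ", reldiff(grad_reversediff, grad_finitediff))
println("ReverseDiff vs ForwardDiff: ", reldiff(grad_reversediff, grad_forwarddiff))
```

Large pairwise discrepancies like these are a cheap signal that the tolerances are too loose for the gradient to be trusted, even when no reference gradient is available.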

frankschae commented 3 weeks ago

Thanks! I think this makes sense now. I did another test where I fixed the solver steps to be the ones taken/selected by the stepsize controller/the adaptive routine with large tolerance. Then, the gradient is the same for all three methods. And this value is identical to the discrete adjoint applied on the adaptive solve.

function loss2(u0)
    _prob = remake(odeprob, u0=u0)
    _sol = (solve(_prob, Heun(),
        #dt=0.1,
        #abstol=0.1,
        #reltol=0.1,
        save_everystep=true,
        tstops = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0],
        save_start=false,
        adaptive=false,
        #controller=IController(), #CustomController(),
        sensealg=ReverseDiffAdjoint()
    ))
    @show (_sol.t)
    _sol = _sol[end]

    return sum(abs2, _sol)
end

begin
    println("FiniteDiff")
    grad1 = FiniteDiff.finite_difference_gradient(loss2, u0)
    println("Forward")
    grad2 = ForwardDiff.gradient(loss2, u0)
    println("Reverse")
    grad3 = Zygote.gradient(loss2, u0)[1]
    @show grad1 grad2 grad3
end
FiniteDiff
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
Forward
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
Reverse
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
grad1 = [-15.024234303085242, 4588.183943943744]
grad2 = [-15.024234205512375, 4588.183943938823]
grad3 = [-15.024234205512368, 4588.183943938822]
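For this particular linear problem, the frozen-step experiment can also be sketched without any solver package. The code below is an illustration under stated assumptions, not the OrdinaryDiffEq implementation: it assumes that Heun with fixed steps on a linear autonomous ODE u' = A u reduces to multiplying by I + hA + h²A²/2 per step, and the names `heun_matrix`, `loss_fixed`, `grad_exact`, and `grad_fd` are hypothetical helpers. Once the steps are fixed, the whole solve is one linear map, so the exact derivative of the discrete computation and a finite difference of it should agree.

```julia
using LinearAlgebra

# Right-hand side of the toy problem from the first post, written as u' = A * u
a11 = -273 / 512
a21 = -1 / 160
a22 = -(-785 / 512 + sqrt(2) / 8)
A = [a11 0.0; a21 a22]

# Start time followed by the step times selected by the adaptive solve (the tstops above)
ts = [0.0, 0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614,
      2.2899305515863184, 2.7977642464962393, 3.0]

# One Heun step for a linear autonomous ODE: u -> (I + h*A + h^2/2 * A^2) * u
heun_matrix(h) = I + h * A + (h^2 / 2) * A^2

# With the steps frozen, the whole solve is a single linear map u(T) = M * u0
M = foldl((P, h) -> heun_matrix(h) * P, diff(ts); init=Matrix(1.0I, 2, 2))

loss_fixed(u0) = sum(abs2, M * u0)
grad_exact(u0) = 2 * M' * (M * u0)   # exact derivative of the discrete map

# Central finite differences on the same frozen-step computation
function grad_fd(u0; h=1e-6)
    map(eachindex(u0)) do i
        e = zeros(length(u0))
        e[i] = h
        (loss_fixed(u0 + e) - loss_fixed(u0 - e)) / (2h)
    end
end

u0 = [1.0, 1.0]
@show grad_exact(u0) grad_fd(u0)
```

Because the step sequence no longer depends on `u0`, perturbing the input cannot change the program being differentiated, which is the situation in which finite differences and discrete adjoints are expected to match.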
lockwo commented 3 weeks ago

I appreciate your detailed response, it's a good summary (maybe we can add it to a FAQ or something). I don't feel totally confident in my understanding, but I will pester Frank till I get all of it :)