patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Explosion of steps for specific parameter values #386

Open FFroehlich opened 8 months ago

FFroehlich commented 8 months ago

I have been experiencing odd integration failures in large sets of solves of relatively small, simple systems of equations. I have narrowed this down to a small example:

import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

jax.config.update("jax_enable_x64", True)

a = jnp.array(
    [
        6.026932645397832,
        4.41195014234956,
        5.884199824299863,
        3.673504195449191,
        4.17957753821087,
    ]
)
b = jnp.float64(
    -2.823760940491063
)

def xdot(t, x, _):
    # Right-hand side evaluated directly.
    c = jnp.exp(a[1]) - x[0]
    d = x[1] / (c + jnp.exp(b))
    dxdt = jnp.asarray(
        (
            jnp.exp(a[3]) * d * c - jnp.exp(a[4]) * x[0],
            jnp.exp(a[3]) * (jnp.exp(a[0]) - x[1]),
        )
    )
    return dxdt

def xdot_lse(t, x, _):
    # Mathematically identical right-hand side, with c and d computed in log
    # space (d via logsumexp). This is the version passed to diffeqsolve below.
    c = jnp.log(jnp.exp(a[1]) - x[0])
    d = jnp.log(x[1]) - jax.nn.logsumexp(jnp.array([b, c]))
    dxdt = jnp.asarray(
        (
            jnp.exp(a[3] + d + c) - jnp.exp(a[4]) * x[0],
            jnp.exp(a[3]) * (jnp.exp(a[0]) - x[1]),
        )
    )
    return dxdt

y0 = jnp.array([4.1154706432848185, 6.831774897154676])
ts = np.concatenate((np.array([0]), np.logspace(-6, 2, 20)))

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(xdot_lse),
    solver=diffrax.Kvaerno5(),
    t0=0.0,
    t1=ts[-1],
    dt0=1e-8,
    y0=y0,
    stepsize_controller=diffrax.PIDController(
        atol=1e-8,
        rtol=1e-6,
        pcoeff=0.4,
        icoeff=0.3,
        dcoeff=0,
    ),
    max_steps=int(1e5),
    saveat=diffrax.SaveAt(ts=ts),
    throw=False,
)
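# Outcome of the solve and the number of accepted/rejected steps.
print(sol.result, sol.stats["num_accepted_steps"], sol.stats["num_rejected_steps"])
# State at the second-to-last save time, just before the problematic last interval.
x = sol.ys[-2, :]

# Evaluate dx/dt[0] with both implementations for perturbations eps of x[0]
# at several magnitudes, and plot the values and their difference.
for mag in np.logspace(-14, -4, 6):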
    xx = np.linspace(-mag, mag, 200)
    fx_lse = jnp.asarray(
        [xdot_lse(0.0, x + jnp.asarray([eps, 0.0]), None)[0] for eps in xx]
    )
    fx = jnp.asarray(
        [xdot(0.0, x + jnp.asarray([eps, 0.0]), None)[0] for eps in xx]
    )
    f, axes = plt.subplots(1, 2)
    axes[0].plot(xx, fx_lse, marker='o', label='xdot_lse')
    axes[0].plot(xx, fx, marker='.', label='xdot')
    axes[0].set_ylabel('value of dx/dt[0] at x+eps')
    axes[0].set_xlabel('eps')
    axes[0].legend()
    axes[1].plot(xx, fx - fx_lse, marker='o')
    axes[1].set_ylabel('implementation difference dx/dt[0] at x+eps')
    axes[1].set_xlabel('eps')
    plt.tight_layout()
    plt.show()

The example should fail at the last time point with ~50k rejected steps and ~50k accepted steps. Minuscule changes to the parameters, e.g. changing the first entry in a from 6.026932645397832 to 6.02693264539783, allow the system to be solved in ~90 steps. This is odd, as the system is pretty close to a steady state when the integration fails and should be easy to integrate.

I initially thought this might be the result of some numerical instability, but I'm no longer convinced that this is the case. For example, changing d to d = x[1] / (c + jnp.exp(b)) (as implemented in xdot) resolves the integration failure, but doesn't result in any appreciable improvement in the numerical stability with which the right-hand side can be evaluated (see the plots generated at the end of the script). The differences I see are on the order of 1e-11 to 1e-12, which in my understanding shouldn't matter much at the tolerances I am using. My conclusion is therefore that I might be hitting some weird numerical edge case.

FFroehlich commented 8 months ago

Turns out I was still using version 0.4.1 of diffrax in that project; this error does not reproduce under 0.5.0. However, I want to note that in 0.4.1 the integration would succeed after adding jax.debug.print("state: {state}", state=controller_state, ordered=True) at https://github.com/patrick-kidger/diffrax/blob/7f30854117d46c01045c0be67435edb3cdb5db74/diffrax/integrate.py#L253, which to me suggests a more serious issue. Will close for now and report back if I encounter something similar.

FFroehlich commented 8 months ago

The problem persists in 0.5.0 with a slightly different parameter value; I have updated the example above. The integration failure can still be "fixed" by adding jax.debug.print("state: {state}", state=controller_state, ordered=True) at https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_integrate.py#L274

patrick-kidger commented 8 months ago

Hmm, interesting. The fact that things change when you add jax.debug.print suggests that this is probably a floating-point thing.

Maybe try adding the debug statement at the very start or very end of the loop? I think it's much less likely that any optimisation/etc. passes are being applied to it there (JAX/XLA do far fewer optimisations across control flow boundaries), in which case that might offer a window into what's going on.
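As a toy illustration of that placement (a hand-written jax.lax.while_loop computing sqrt(2) by Newton's method, not Diffrax's actual integration loop), the prints can bracket the loop body so that they record exactly the values entering and leaving each iteration:

```python
import jax
import jax.numpy as jnp


def newton_sqrt2(x0):
    # Newton iteration for x**2 = 2, written as a lax.while_loop so that the
    # body is a separate computation behind a control-flow boundary.
    def cond(carry):
        i, x = carry
        return (i < 20) & (jnp.abs(x * x - 2.0) > 1e-6)

    def body(carry):
        i, x = carry
        # Very start of the body: the carry entering this iteration.
        jax.debug.print("step {i} start: x = {x}", i=i, x=x)
        x_new = x - (x * x - 2.0) / (2.0 * x)
        # Very end of the body: the carry leaving this iteration.
        jax.debug.print("step {i} end: x = {x}", i=i, x=x_new)
        return i + 1, x_new

    return jax.lax.while_loop(cond, body, (jnp.asarray(0), x0))


num_steps, root = jax.jit(newton_sqrt2)(jnp.float32(1.0))
```

Because both prints sit at the boundaries of the body, they only observe the loop carry, which has to be materialised between iterations anyway. That is the intuition behind the suggestion, not a guarantee about which optimisations XLA applies.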

Another option for debugging might be to use SaveAt(steps=True), possibly in conjunction with varying t1 whilst looking at sol.stats.

FFroehlich commented 8 months ago

It only happens with a debug statement that involves the controller state.

FFroehlich commented 8 months ago

I have tried debugging this with SaveAt(steps=True) and inspecting the results, but it hasn't been particularly insightful. Any slight change to the overall configuration tends to make the problem magically disappear, but in sets of thousands of ODE solves there always tend to be a handful that fail (assuming there is only one problem).

Why would the debug statement point to a floating point issue?

patrick-kidger commented 8 months ago

So adding in a debug statement changes the nature of the program very little. The only thing it really changes is to require that the outputted value must exist -- and in particular, that its node in the computation graph cannot be optimised by the compiler. For example, given an integer a, then

b = a + 1
c = b - 1

would probably just get optimised down to c = a. However if we added in a jax.debug.print('{}', b), then we will in fact have to compute b, and so the computation has changed ever-so-slightly.

In practice, optimisations are of course meant not to change the behaviour of the program. So if they do, it's usually due to any of the many twiddly floating point gotchas that can subtly adjust results.

None of the above is 100% btw, it's more a rule of thumb.
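To see why such a change can matter for floating-point values (unlike the integer case above), here is a plain-Python sketch comparing the two algebraically equivalent forms. XLA normally avoids value-changing rewrites like this for floats, so treat it only as an illustration of how sensitive rounding is to the exact sequence of operations:

```python
def via_intermediate(a: float) -> float:
    # Materialise b = a + 1: it is rounded to the nearest double near 1.0.
    b = a + 1.0
    return b - 1.0


def simplified(a: float) -> float:
    # What (a + 1) - 1 reduces to algebraically.
    return a


for a in (0.1, 1e-17):
    print(a, via_intermediate(a), simplified(a))
# 0.1 0.10000000000000009 0.1
# 1e-17 0.0 1e-17
```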


If it's only the controller state that causes issues with debugging, then that sounds like we can still insert jax.debug.print statements for everything else, and in doing so mostly see what's going on?

FFroehlich commented 8 months ago

Right, so I've done that and I understand what is going on: the nonlinear solve diverges once and then appears to be stuck, failing repeatedly. However, I don't see why the nonlinear solve would fail: it's a pretty simple problem, pretty close to steady state, and the solver had been taking huge steps with small predicted errors before.

Unfortunately, I can't seem to use ordered=True in the print statements, as it produces an error, but the printed statements appear to be well-ordered anyway. I have removed most of the early outputs, as nothing interesting happens there, and only kept the statements from _integrate to illustrate the step sizes the solver was taking before the failure. I have also set max_steps=int(1e2) to keep the output manageable. You can see the edited debugging output here:

https://gist.github.com/FFroehlich/a7378fcba87af32d307894e43dd82367

patrick-kidger commented 8 months ago

If it's specifically the nonlinear solve that is doing odd things, then perhaps this is specifically an issue with Optimistix. (Or with VeryChord?) Would it be possible to extract the inputs to the nonlinear solve we make, and consider that in isolation?

One immediate possible culprit that comes to mind is that you might be right on the edge of this acceptance criterion:

https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_root_finder/_verychord.py#L18-L22

which might potentially have tolerances that are still slightly too loose.

FFroehlich commented 7 months ago

Well, isolating the inputs to the nonlinear solve is ~a massive pain~ non-trivial, since it requires reconstructing the nonlinear function, which depends on Butcher tableaus etc. The way that https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_solver/runge_kutta.py#L444 is implemented means that I would have to copy a ton of code to reconstruct the inputs to https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_solver/runge_kutta.py#L984. I tried, but lost interest after assembling ~100 lines of code that were scattered throughout the whole file or imported from elsewhere.

In the end, I am not really convinced that going through this exercise would prove particularly insightful. The nonlinear solver diverges, which is bound to happen with Newton schemes that lack a globalisation strategy.

I suppose both the termination criteria and the (lack of) globalisation strategies are largely motivated by tribal knowledge and empirical choices based on a set of test problems that were established in the mathematics community decades ago. For example, SUNDIALS uses slightly different criteria (https://github.com/LLNL/sundials/blob/2abd63bd6cbc354fb4861bba8e98d0b95d65e24a/src/cvodes/cvodes_nls.c#L303) compared to the conditions in diffrax, which seem to be based on the 1996 book by Hairer & Wanner. I haven't found any documentation or code describing what the Julia folks are currently using, but they appear to plan to provide more flexibility through https://github.com/SciML/NonlinearSolve.jl.
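For comparison, here is a plain-Python sketch of the two styles of test, applied to the scaled norms of the last two Newton updates (so it assumes at least two iterations have been taken and diffsize_prev > 0). The Hairer & Wanner version follows the textbook rule: divergence if the rate is at least 1, convergence once rate / (1 - rate) * diffsize drops below kappa. The value kappa = 1e-2 is an assumption, and Diffrax's own thresholds and safeguards may differ. The SUNDIALS-style version mirrors the decision logic of the SundialsChord.terminate method posted later in this thread:

```python
def hairer_wanner_rule(diffsize, diffsize_prev, kappa=1e-2):
    # Textbook-style test; norms are already scaled by atol + rtol * |y|.
    rate = diffsize / diffsize_prev  # estimated contraction rate
    if rate >= 1.0:
        return "diverged"
    factor = diffsize * rate / (1.0 - rate)  # estimate of the remaining error
    return "converged" if factor < kappa else "continue"


def sundials_like_rule(diffsize, diffsize_prev, threshold=0.2, rdiv=2.0):
    # Same decision logic as SundialsChord.terminate (without the step >= 2 guard).
    rate = diffsize / diffsize_prev
    if diffsize > rdiv * diffsize_prev:
        return "diverged"
    factor = diffsize * min(1.0, rate) / threshold
    return "converged" if factor < 1.0 else "continue"


# Near a steady state: two tiny, noise-dominated updates, the second slightly larger.
print(hairer_wanner_rule(6e-4, 5e-4))  # diverged
print(sundials_like_rule(6e-4, 5e-4))  # converged
```

The textbook rule rejects the iterate purely because the rate exceeds 1, while the SUNDIALS-style rule only declares divergence if the update more than doubles, and otherwise accepts it because the update itself is already far below one tolerance unit.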

patrick-kidger commented 7 months ago

That's fair! Reconstructing the inputs for that isn't super easy.

FWIW NonlinearSolve.jl is basically the equivalent of our Optimistix. In terms of globalisation mechanisms, you might find the relevant page of the Optimistix documentation interesting, as it discusses how Optimistix makes it possible to mix and match the various pieces of a globalisation mechanism. (Although we elected not to use the 'globalisation' terminology, which mostly just seems to promote confusion.)

Whilst such strategies aren't used by default in Diffrax, they are available to the advanced user. Diffrax will use any Optimistix root-finder you like (more precisely, anything implementing the optx.AbstractRootFinder interface), so you could implement one with a line search etc. if you wish. We've been meaning to add built-in support for this to optx.Newton, optx.Chord and diffrax.VeryChord -- analogous to what is already done in e.g. optx.GaussNewton -- and just haven't gotten around to it yet.
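As a standalone illustration of one such mechanism, here is a NumPy sketch of Newton's method with a backtracking (Armijo) line search on the merit function 0.5 * ||F(x)||^2. It is not an optx.AbstractRootFinder and not how Optimistix structures its solvers. The tolerances and constants are arbitrary choices:

```python
import numpy as np


def damped_newton(F, J, x0, tol=1e-10, max_steps=50, c=1e-4, shrink=0.5):
    """Newton's method globalised with a backtracking (Armijo) line search.

    Merit function m(x) = 0.5 * ||F(x)||^2. Along the full Newton direction
    p = -J(x)^{-1} F(x), its directional derivative is -||F(x)||^2 = -2 m(x).
    """
    x = np.asarray(x0, dtype=float)
    for _ in range(max_steps):
        fx = F(x)
        merit = 0.5 * fx @ fx
        if np.sqrt(2.0 * merit) < tol:  # ||F(x)|| small enough
            break
        p = np.linalg.solve(J(x), -fx)  # full Newton step
        alpha = 1.0
        # Shrink the step until the sufficient-decrease condition holds:
        # m(x + alpha p) <= m(x) - 2 c alpha m(x).
        while alpha > 1e-10:
            f_new = F(x + alpha * p)
            if 0.5 * f_new @ f_new <= (1.0 - 2.0 * c * alpha) * merit:
                break
            alpha *= shrink
        x = x + alpha * p
    return x


def F(x):
    return np.arctan(x)


def J(x):
    return np.atleast_2d(1.0 / (1.0 + x[0] ** 2))


x0 = np.array([10.0])
root = damped_newton(F, J, x0)

# Plain (undamped) Newton from the same start overshoots further each step.
x_plain = x0.copy()
for _ in range(3):
    x_plain = x_plain - np.linalg.solve(J(x_plain), F(x_plain))

print(root, x_plain)
```

On F(x) = arctan(x) from x0 = 10, full Newton steps overshoot and diverge. The line search instead keeps shrinking the step until the residual actually decreases, and the iteration then converges to the root at 0.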

(If you think that's the appropriate solution to your problem, and are feeling sufficiently motivated, then we'd certainly be happy to take a PR on that!)

FFroehlich commented 7 months ago

Okay, I managed to consider the non-linear solve in isolation by starting integration at the problematic integration step.

The nonlinear solver failures can be remedied by a variety of things: "re-starting" the solver by not passing the respective solver_state, or making the divergence check a bit more lenient (the Newton method then converges in the next step).

I no longer think this is a bug, but rather that the nonlinear solver is not sufficiently robust for my purposes. It's probably the combination of large sample sizes and the limited numerical accuracy of the right-hand side. Close to (or at) steady state, the Newton step is effectively just noise, so the divergence check will fail occasionally. In this case it seems to fail remarkably often, even in successive integration steps, which might hint at some other structure in the problem that I am still missing. _small, which should presumably handle such settings, checks for diffsize < 1e-13 when using double precision, which I think is too strict. diffsize is computed as ||diff / (atol + y_new * rtol)||, so, for example, with rtol=0.0 and atol=1e-8, this would check for ||diff|| < 1e-21, which does not seem reasonable.
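To make the scale of that threshold concrete, here is a small NumPy calculation using y0 from the example above and the 1e-13 threshold quoted in this comment. It assumes an RMS norm and uses |y_new| in the scale. Both are assumptions about the exact implementation:

```python
import numpy as np


def scaled_rms_norm(diff, y, atol, rtol):
    # ||diff / (atol + |y| * rtol)||, with an RMS norm (assumed).
    scale = atol + np.abs(y) * rtol
    return float(np.sqrt(np.mean((diff / scale) ** 2)))


SMALL = 1e-13  # double-precision threshold quoted above
y = np.array([4.1154706432848185, 6.831774897154676])  # y0 from the example

for atol, rtol in [(1e-8, 0.0), (1e-8, 1e-6)]:
    # Largest per-component update that would still count as "small".
    max_diff = SMALL * (atol + np.abs(y) * rtol)
    # Spacing between adjacent doubles at y (one unit in the last place).
    ulp = np.spacing(y)
    print(atol, rtol, max_diff, max_diff / ulp)
```

For both tolerance settings, the largest update that counts as "small" is below half the spacing between adjacent doubles at these values of y (about 8.9e-16). Such an update could not change y at all if added to it.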

I have re-implemented a convergence check that is similar to what is done in SUNDIALS (except that the rate is not stored in the state):

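import diffrax
import jax.numpy as jnp
import optimistix as optx


class SundialsChord(diffrax.VeryChord):
    # VeryChord with a SUNDIALS-style termination test. The convergence rate is
    # recomputed from the last two update norms instead of being stored in the state.
    def terminate(
        self,
        fn: callable,
        y,
        args,
        options,
        state,
        tags: frozenset[object],
    ):
        del fn, y, args, options, tags
        # Estimated convergence rate: ratio of successive scaled update norms.
        rate = state.diffsize / state.diffsize_prev
        # Convergence test quantity (compared against 1); nanmin ignores a NaN rate.
        factor = state.diffsize * jnp.nanmin(jnp.array([1.0, rate])) / 0.2
        converged = jnp.logical_and(state.step >= 2, factor < 1.0)
        # Divergence: the update norm more than doubled.
        diverged = jnp.logical_and(
            state.step >= 2, state.diffsize > 2.0 * state.diffsize_prev
        )
        terminate = diverged | converged
        terminate_result = optx.RESULTS.where(
            diverged | jnp.invert(converged),
            optx.RESULTS.nonlinear_divergence,
            optx.RESULTS.successful,
        )
        # A failed linear solve takes precedence over the convergence checks.
        linsolve_fail = state.result != optx.RESULTS.successful
        result = optx.RESULTS.where(
            linsolve_fail, state.result, terminate_result
        )
        terminate = linsolve_fail | terminate
        return terminate, result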

which solves the particular problem and uses slightly fewer steps, but doesn't appear to work well on other problems.

patrick-kidger commented 7 months ago

Nice! I'm really glad you got this working. :)