FFroehlich opened 8 months ago
Turns out I was still using version 0.4.1 of diffrax in that project; this error does not reproduce under 0.5.0. However, I want to note that the integration would succeed in 0.4.1 after adding `jax.debug.print("state: {state}", state=controller_state, ordered=True)` at https://github.com/patrick-kidger/diffrax/blob/7f30854117d46c01045c0be67435edb3cdb5db74/diffrax/integrate.py#L253, which to me suggests a more serious issue. Will close for now and report back if I encounter something similar.
Problem persists in 0.5.0
With a slightly different parameter value (updated example above), the integration failure can still be "fixed" by adding `jax.debug.print("state: {state}", state=controller_state, ordered=True)` at https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_integrate.py#L274
Hmm, interesting. The fact that things change when you add `jax.debug.print` suggests that this is probably a floating-point thing.
Maybe try adding the debug statement at the very start or very end of the loop? I think it's much less likely that any optimisation/etc. passes are being applied to it there (JAX/XLA do far fewer optimisations across control flow boundaries), in which case that might offer a window into what's going on.
Another option for debugging might be to use `SaveAt(steps=True)`, possibly in conjunction with varying `t1` whilst looking at `sol.stats`.
It only happens with a debug statement that involves the controller state.
I have tried debugging this with `steps=True` and looking at the output, but it hasn't been particularly insightful. Any slight change to the overall configuration tends to make the problem magically disappear, but in sets of thousands of ODE solves there always tend to be a handful that fail (assuming there is only one problem).
Why would the debug statement point to a floating point issue?
So adding in a debug statement changes the nature of the program very little. The only thing it really changes is to require that the outputted value must exist -- and in particular, that its node in the computation graph cannot be optimised away by the compiler. For example, given an integer `a`, then

```python
b = a + 1
c = b - 1
```

would probably just get optimised down to `c = a`. However, if we added in a `jax.debug.print('{}', b)`, then we will in fact have to compute `b`, and so the computation has changed ever-so-slightly.
In practice, optimisations are of course meant not to change the behaviour of the program. So if they do, it's usually due to any of the many twiddly floating point gotchas that can subtly adjust results.
None of the above is 100% btw, it's more a rule of thumb.
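The point above can be sketched concretely; this is a toy example, not diffrax code:

```python
import jax

@jax.jit
def without_print(a):
    b = a + 1
    c = b - 1
    return c  # XLA is free to simplify this whole body to `return a`

@jax.jit
def with_print(a):
    b = a + 1
    jax.debug.print("b = {}", b)  # `b` must now actually be materialised
    c = b - 1
    return c

# Both return the same value; only the compiled computation differs.
assert int(without_print(3)) == int(with_print(3)) == 3
```

The result is unchanged here, but in floating-point code the slightly different computation graph can shift results in the last few ulps.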
If it's only the controller state that causes issues with debugging, then that sounds like we can still insert `jax.debug.print` statements for everything else, and in doing so mostly see what's going on?
Right, so I've done that and I understand what is going on: the nonlinear solve diverges once and then appears to be stuck and repeatedly fails. However, I don't see why the nonlinear solve would fail; it's a pretty simple problem, pretty close to steady state, and the solver has been taking huge steps with small predicted errors before.
You can see an edited debugging output here, using

- `jax.debug.print("### diffsize: {diffsize}, diffsize_prev: {diffsize_prev}, rate: {rate}, factor: {factor}, small: {small}, diverged: {diverged}, converged: {converged}", diffsize=state.diffsize, diffsize_prev=state.diffsize_prev, rate=rate, factor=factor, small=small, diverged=diverged, converged=converged)` at https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_root_finder/_verychord.py#L165
- `jax.debug.print("## stage_index={stage_index}, result={result}", stage_index=stage_index, result=result)` at https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_solver/runge_kutta.py#L1058
- `jax.debug.print("# tprev: {tprev}, tnext: {tnext}, y: {y}, y_error: {y_error}, state: {state}, result: {result}", tprev=state.tprev, tnext=state.tnext, y=y, y_error=y_error, result=solver_result, state=state.solver_state)` at https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_integrate.py#L249

Unfortunately, I don't seem to be able to use `ordered=True` in the print statements as it produces some error, but the printed statements appear to be well-ordered anyway.
I have removed most of the early outputs, as there really isn't anything interesting happening, and only kept the statements from `_integrate` to illustrate the stepsize the solver was taking before the failure. I have also set `max_steps=int(1e2)` to keep the output manageable:
https://gist.github.com/FFroehlich/a7378fcba87af32d307894e43dd82367
If it's specifically the nonlinear solve that is doing odd things, then perhaps this is specifically an issue with Optimistix. (Or with `VeryChord`?) Would it be possible to extract the inputs to the nonlinear solve we make, and consider that in isolation?
One immediate possible culprit that comes to mind is that you might be right on the edge of these acceptance criteria:
https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_root_finder/_verychord.py#L18-L22
which might potentially have tolerances that are still slightly too loose.
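For context, a Hairer & Wanner style convergence test roughly has the following shape (a sketch with a made-up `kappa`; this is not the actual diffrax code):

```python
def hw_converged(diffsize, diffsize_prev, kappa=0.01):
    """Sketch of a Hairer & Wanner style convergence test.

    `diffsize` is the scaled norm of the latest Newton/chord update.
    The contraction rate of the last two updates is used to extrapolate
    the remaining distance to the fixed point; accept once that estimate
    drops below `kappa`.
    """
    rate = diffsize / diffsize_prev
    return rate / (1.0 - rate) * diffsize < kappa

# Fast contraction: the estimated remaining error is tiny, so accept.
assert hw_converged(1e-6, 1e-3)
# Barely contracting: a rate near 1 blows up the estimate, so reject.
assert not hw_converged(9.9e-4, 1e-3)
```

Being "right on the edge" here means `rate` hovering around the acceptance boundary, so tiny floating-point perturbations flip the outcome.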
Well, isolating the inputs to the nonlinear solve is ~~a massive pain~~ non-trivial, since it requires reconstructing the nonlinear function, which depends on Butcher tableaus etc. The way that https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_solver/runge_kutta.py#L444 is implemented means that I would have to copy a ton of code to reconstruct the inputs to https://github.com/patrick-kidger/diffrax/blob/d97ba2006426836b9b57dfed8d2c24c7373567e0/diffrax/_solver/runge_kutta.py#L984. I tried, but lost interest after assembling ~100 lines of code that were scattered throughout the whole file and imported from elsewhere.
In the end, I am not really convinced that going through this exercise would prove particularly insightful. The non-linear solver diverges, which is bound to happen with Newton schemes without a globalisation strategy.
I suppose both the termination criteria and the (lack of) globalisation strategies are largely motivated by tribal knowledge and empirical choices, based on a set of test problems that were established in the mathematics community decades ago. For example, SUNDIALS uses slightly different criteria (https://github.com/LLNL/sundials/blob/2abd63bd6cbc354fb4861bba8e98d0b95d65e24a/src/cvodes/cvodes_nls.c#L303) compared to the conditions in diffrax, which seem to be based on the 1996 book by Hairer & Wanner. I haven't found any documentation/code on what the Julia folks are using at the moment, but they appear to plan to give more flexibility through https://github.com/SciML/NonlinearSolve.jl.
That's fair! Reconstructing the inputs for that isn't super easy.
FWIW, `NonlinearSolve.jl` is basically the equivalent of our Optimistix. In terms of globalisation mechanisms, you might find the relevant page of the Optimistix documentation interesting, as it discusses how Optimistix makes it possible to mix-and-match the various pieces of a globalisation mechanism. (Although we elected not to use the 'globalisation' terminology, which mostly just seems to promote confusion.)
Whilst such strategies aren't used by default in Diffrax, they are available to the advanced user. Diffrax will use any Optimistix root-finder you like (more precisely, anything implementing the `optx.AbstractRootFinder` interface), so you could implement one with a line search etc. if you wish. We've been meaning to add built-in support for this to `optx.Newton`, `optx.Chord` and `diffrax.VeryChord` -- analogous to what is already done in e.g. `optx.GaussNewton` -- and just haven't gotten around to it yet.
(If you think that's the appropriate solution to your problem, and are feeling sufficiently motivated, then we'd certainly be happy to take a PR on that!)
Okay, I managed to consider the non-linear solve in isolation by starting integration at the problematic integration step.
The non-linear solver failures can be remedied by a variety of things: "re-starting" the solver by not passing the respective `solver_state`, or changing the divergence check to be a bit more lenient (the Newton method will converge in the next step).
I no longer think this is a bug, but rather the non-linear solver not being sufficiently robust for my purposes. It's probably the combination of large sample sizes and limited numerical accuracy of the right-hand side. Close to steady state, the Newton step is effectively just noise, so the divergence check will just fail occasionally. In this case it seems to fail remarkably often, even in successive integration steps, which might hint at some other structure to the problem that I am still missing. `_small`, which should likely handle such settings, checks for `diffsize < 1e-13` when using double precision, which I think is too strict. `diffsize` is computed as `||diff / (atol + y_new * rtol)||`, so for example with `rtol=0.0` and `atol=1e-8`, this would check for `||diff|| < 1e-21`, which does not seem reasonable.
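To make the arithmetic concrete, here is a sketch of that scaling (the RMS norm is an assumption; the actual norm used may differ):

```python
import jax.numpy as jnp

def diffsize(diff, y_new, rtol, atol):
    # Scaled RMS norm of the Newton update, matching the formula above.
    scaled = diff / (atol + jnp.abs(y_new) * rtol)
    return jnp.sqrt(jnp.mean(scaled**2))

# With rtol=0 and atol=1e-8, even an update of ||diff|| ~ 1e-20 --
# pure floating-point noise for O(1) state values -- scales to 1e-12,
# which still fails the `diffsize < 1e-13` smallness check.
d = diffsize(jnp.array([1e-20]), jnp.array([1.0]), rtol=0.0, atol=1e-8)
assert d > 1e-13
```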
I have re-implemented a convergence check that is similar to what is done in SUNDIALS (`rate` is not stored in the state):
```python
import diffrax
import jax.numpy as jnp
import optimistix as optx
from collections.abc import Callable


class SundialsChord(diffrax.VeryChord):
    def terminate(
        self,
        fn: Callable,
        y,
        args,
        options,
        state,
        tags: frozenset[object],
    ):
        del fn, y, args, options, tags
        # SUNDIALS-style check: extrapolate the remaining error from the
        # contraction rate of the last two updates.
        rate = state.diffsize / state.diffsize_prev
        factor = state.diffsize * jnp.nanmin(jnp.array([1.0, rate])) / 0.2
        converged = jnp.logical_and(state.step >= 2, factor < 1.0)
        diverged = jnp.logical_and(
            state.step >= 2, state.diffsize > 2.0 * state.diffsize_prev
        )
        terminate = diverged | converged
        terminate_result = optx.RESULTS.where(
            diverged | jnp.invert(converged),
            optx.RESULTS.nonlinear_divergence,
            optx.RESULTS.successful,
        )
        linsolve_fail = state.result != optx.RESULTS.successful
        result = optx.RESULTS.where(linsolve_fail, state.result, terminate_result)
        terminate = linsolve_fail | terminate
        return terminate, result
```
This solves the particular problem and uses a bit fewer steps, but doesn't appear to work well on other problems.
Nice! I'm really glad you got this working. :)
I have been experiencing odd integration failures in large sets of solves of relatively small, simple systems of equations. I have narrowed this down to a small example:
The example should fail at the last time-point with ~50k rejected steps and ~50k accepted steps. Minuscule changes to the parameters, e.g. changing the first entry in `a` from `6.026932645397832` to `6.02693264539783`, allow the system to be solved in ~90 steps. This is odd, as the system is pretty close to a steady state when the integration fails and should be easy to integrate.

I initially thought this might be the result of some numerical instability, but I'm no longer convinced that this is the case. For example, changing `d` to `d = x[1]/(c + jnp.exp(b[0]))` (implemented in `xdot`) resolves the integration failure, but doesn't result in any appreciable improvement in the numerical stability with which the right-hand side can be evaluated (see the plots generated at the end of the script). The magnitude of the changes that I see is in the range of 1e-11 to 1e-12, which in my understanding shouldn't matter too much for the tolerances that I am using. Therefore, my conclusion is that I might be hitting some weird numerical edge-case.