patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Steady-state event occurs immediately if state with large value constant #448

Open johannahaffner opened 1 month ago

johannahaffner commented 1 month ago

Hi Patrick,

I noticed that steady-state events occur much earlier if a state with a large value is constant. This happens irrespective of the optx.norm specified, even though I would have expected it not to occur with optx.max_norm. (So it looks like a separate concern, even if it is related to #131.)

Here is some code that reproduces this:

```python
import jax.numpy as jnp
import diffrax as dfx
import optimistix as optx

def dydt(t, y, k):
    return -k * y

# Scenario A: one state constant at high value
k0_a = jnp.array([0., 1.])
y0_a = jnp.array([1e05, 5.])

# Scenario B: only dynamic state, rate and value identical
k0_b = jnp.array([1.])
y0_b = jnp.array([5.])

# Steady state reached earlier in scenario A, holds for all norms
norms = [optx.max_norm, optx.two_norm, optx.rms_norm]
for norm in norms:
    steady_state = dfx.SteadyStateEvent(atol=1e-06, rtol=1e-03, norm=norm)

    # Scenario A
    solution_a = dfx.diffeqsolve(
        dfx.ODETerm(dydt),
        dfx.Tsit5(),
        0, 100., 0.1, y0_a, args=k0_a,
        discrete_terminating_event=steady_state,
    )
    assert solution_a.result == dfx.RESULTS.discrete_terminating_event_occurred

    # Scenario B: steady state reached much later
    solution_b = dfx.diffeqsolve(
        dfx.ODETerm(dydt),
        dfx.Tsit5(),
        0, 100., 0.1, y0_b, args=k0_b,
        discrete_terminating_event=steady_state,
    )
    assert solution_b.result == dfx.RESULTS.discrete_terminating_event_occurred

    # Compare the times
    print(solution_a.ts, solution_b.ts)
```

Changing the tolerances will change the time points at which the event occurs, but the system will still reach a (spurious) steady state in scenario A.

I use these in practice to find steady states of models representing biological systems, where state values can differ wildly, but "small" states can still exert great influence on other states. (I have two large constant states because I set their growth rates to zero while trying to get rid of parameter-dependent transients in the small states.)

johannahaffner commented 1 month ago

I did some digging and now understand that this is actually expected behavior, since SteadyStateEvent checks the condition against the norm of the full system state:

```python
vf = solver.func(terms, state.tprev, state.y, args)
return self.norm(vf) < _atol + _rtol * self.norm(state.y)
```
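To see why the event fires at the very first step, here is a minimal sketch that evaluates this condition by hand at t0 for both scenarios from the first comment. It does not call diffrax; `max_norm`, `two_norm` and `rms_norm` are hypothetical NumPy stand-ins for the Optimistix norms, assuming they compute the usual max, Euclidean and root-mean-square norms.

```python
import numpy as np

# Hypothetical NumPy stand-ins for optx.max_norm, optx.two_norm, optx.rms_norm
def max_norm(x):
    return float(np.max(np.abs(x)))

def two_norm(x):
    return float(np.sqrt(np.sum(np.square(x))))

def rms_norm(x):
    return float(np.sqrt(np.mean(np.square(x))))

def steady(vf, y, atol, rtol, norm):
    # Same form as the quoted condition: ||f(t, y)|| < atol + rtol * ||y||
    return norm(vf) < atol + rtol * norm(y)

def dydt(t, y, k):
    return -k * y

atol, rtol = 1e-6, 1e-3
k_a, y_a = np.array([0.0, 1.0]), np.array([1e5, 5.0])  # Scenario A
k_b, y_b = np.array([1.0]), np.array([5.0])            # Scenario B

for norm in (max_norm, two_norm, rms_norm):
    a = steady(dydt(0.0, y_a, k_a), y_a, atol, rtol, norm)
    b = steady(dydt(0.0, y_b, k_b), y_b, atol, rtol, norm)
    print(norm.__name__, a, b)  # prints "<name> True False" for each norm
```

In scenario A, `rtol * norm(y)` is on the order of 1e-3 * 1e5 = 100 for every norm, because the large constant state dominates `norm(y)`; that easily exceeds `norm(vf)` ≈ 5, so even `max_norm` does not help.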

I had assumed that the event compared each state against its own initial (or maximum) rate of change. Could we introduce an element-wise comparison?
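An element-wise comparison could look roughly like the sketch below. `steady_elementwise` is a hypothetical helper, not part of diffrax: it requires every component to satisfy |f_i| < atol + rtol * |y_i|, so a large constant state can no longer mask a small state that is still changing.

```python
import numpy as np

def steady_elementwise(vf, y, atol, rtol):
    # Hypothetical per-component check: each |f_i| must be below its own
    # tolerance atol + rtol * |y_i|, instead of comparing two global norms.
    return bool(np.all(np.abs(vf) < atol + rtol * np.abs(y)))

k = np.array([0.0, 1.0])
y0 = np.array([1e5, 5.0])  # Scenario A at t0
print(steady_elementwise(-k * y0, y0, atol=1e-6, rtol=1e-3))  # False: small state still moving

y_late = np.array([1e5, 1e-9])  # small state has (almost) decayed
print(steady_elementwise(-k * y_late, y_late, atol=1e-6, rtol=1e-3))  # True
```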

johannahaffner commented 1 month ago

Oh, damn. I need to take a break :D

The way to flip this is to set the relative tolerance much lower than the absolute one.

This issue can be closed, since it isn't actually an issue.

Edit: the negative exponent of the specified relative tolerance needs to be (substantially) larger in magnitude, resulting in a lower tolerance.
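A minimal sketch of this workaround, again evaluating the quoted condition by hand rather than through diffrax. `max_norm` is a hypothetical NumPy stand-in for `optx.max_norm`, and `rtol=1e-12` is only an illustrative value. With the lower rtol, the large constant state contributes just 1e-12 * 1e5 = 1e-7 to the threshold, which is then dominated by `atol`.

```python
import numpy as np

def max_norm(x):
    # Hypothetical NumPy stand-in for optx.max_norm
    return float(np.max(np.abs(x)))

def steady(vf, y, atol, rtol):
    # The quoted condition: ||f|| < atol + rtol * ||y||
    return max_norm(vf) < atol + rtol * max_norm(y)

k = np.array([0.0, 1.0])
y0 = np.array([1e5, 5.0])  # Scenario A at t0

print(steady(-k * y0, y0, atol=1e-6, rtol=1e-3))   # True: fires immediately
print(steady(-k * y0, y0, atol=1e-6, rtol=1e-12))  # False: threshold is now ~1.1e-6

y_late = np.array([1e5, 1e-7])  # small state has decayed
print(steady(-k * y_late, y_late, atol=1e-6, rtol=1e-12))  # True
```

In the reproduction above this corresponds to `dfx.SteadyStateEvent(atol=1e-06, rtol=1e-12, norm=norm)`. The trade-off is that the criterion becomes effectively absolute, so it no longer adapts to the scale of the states.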

patrick-kidger commented 1 month ago

Ah, this is actually an interesting point! I've been aware of this wart but I'm not completely sure how to fix it. Pretty much everywhere else we do actually compute the norm of a ratio, for precisely this reason.
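For comparison, a "norm of a ratio" check could scale each component of the vector field by its own tolerance before taking the norm, similar to the scaled error norm used by adaptive step-size controllers. This is a sketch under that assumption, not diffrax's implementation; `steady_ratio` and the NumPy norm helpers are hypothetical, and the check declares a steady state when norm(f / (atol + rtol * |y|)) < 1.

```python
import numpy as np

def max_norm(x):
    return float(np.max(np.abs(x)))

def rms_norm(x):
    return float(np.sqrt(np.mean(np.square(x))))

def steady_ratio(vf, y, atol, rtol, norm):
    # Hypothetical: divide each component of f by its own tolerance, then take the norm
    scale = atol + rtol * np.abs(y)
    return norm(vf / scale) < 1.0

k = np.array([0.0, 1.0])
y0 = np.array([1e5, 5.0])  # Scenario A at t0
for norm in (max_norm, rms_norm):
    print(norm.__name__, steady_ratio(-k * y0, y0, 1e-6, 1e-3, norm))  # False for both
```

Here the large constant state contributes a zero to the ratio instead of inflating the threshold. With `max_norm`, this test is equivalent to requiring |f_i| < atol + rtol * |y_i| for every component.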

The problem is that in general the 'vector field space' and the 'y space' are different things: they're related to each other by the bilinear transformation AbstractTerm.prod.

So for the ODE we can divide y / f (and check that this is large; note that I've flipped this around from the usual f / y), but for an SDE I'm guessing the closest equivalent would be g^{-1} y, for which we have to solve a linear system. (Which is why I flipped the ODE case around: I don't know of a way to find an analogue of "g / y".)
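To make the analogy concrete: for an ODE the "division" y / f is element-wise, which is the same as solving the diagonal system diag(f) w = y. For an SDE whose diffusion g is a full matrix, g^{-1} y needs a general linear solve. A minimal NumPy sketch; the values of y, f and g are made up for illustration, and g is assumed square and invertible here.

```python
import numpy as np

y = np.array([2.0, 6.0])
f = np.array([4.0, -3.0])  # illustrative ODE vector field at (t, y)

# ODE case: element-wise y / f equals solving diag(f) w = y
w_ode = y / f
assert np.allclose(w_ode, np.linalg.solve(np.diag(f), y))

# SDE case: g is a matrix, so g^{-1} y requires a genuine linear solve
g = np.array([[2.0, 1.0],
              [0.0, 3.0]])  # illustrative, invertible
w_sde = np.linalg.solve(g, y)
print(w_ode, w_sde)  # w_ode = [0.5, -2.0], w_sde = [0.0, 2.0]
```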

Maybe the correct solution here really is to bite the bullet and call lineax.linear_solve? I'm not sure.

johannahaffner commented 3 weeks ago

Hi!

I'm sorry for not getting back to you sooner - I did take that break.

On the ODE side of things, the norm of a ratio would be more intuitive. One wouldn't have to think about what the tolerances do, specifically. An absolute tolerance for a ratio would also be meaningless.

I hadn't considered the SDE case, thank you for pointing that out! Could we assume that g would always be invertible, though? That seems like a strong requirement.

Is that why you refer to biting the bullet? A linear solve does not seem too expensive to me, but then again the systems I am dealing with are not huge.

patrick-kidger commented 3 weeks ago

I think the bullet is just figuring out how to implement it! I agree that a linear solve probably isn't that expensive a lot of the time. (At the very least, I think it's better to default to slow-and-correct.)

Indeed the diffusion won't be invertible in general. I'm guessing that using the pseudoinverse is probably fine in that case. (?)
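A minimal sketch of the pseudoinverse route, with a made-up non-square diffusion matrix of shape (n, m) = (3, 2), for which g^{-1} does not exist. `np.linalg.pinv` gives the Moore-Penrose pseudoinverse, which agrees with the least-squares solution from `np.linalg.lstsq`; in a JAX setting the analogous solve would presumably go through lineax, which is not shown here.

```python
import numpy as np

# Illustrative diffusion matrix: 3 states driven by 2 noise channels (not square)
g = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 1.0]])
y = np.array([1.0, 4.0, 3.0])

# Least-squares solution of g w = y via the Moore-Penrose pseudoinverse
w_pinv = np.linalg.pinv(g) @ y
w_lstsq, *_ = np.linalg.lstsq(g, y, rcond=None)
assert np.allclose(w_pinv, w_lstsq)
print(w_pinv)  # approximately [1.0, 2.0]; y lies in the range of g here
```

When y is not in the range of g, the pseudoinverse returns the least-squares solution instead, so any steady-state criterion built on it would ignore the component of y that the noise cannot reach.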

johannahaffner commented 1 week ago

Ah sorry, I missed this notification. And thank you for the clarification!

I'll keep this in the back of my head. I have quite a few things on my plate at the moment, and given that there is a workaround, that is good enough for now :)

johannahaffner commented 1 week ago

PS: I did take a very brief look at it. I'm not sure the ratio approach would be general enough if we keep looking at a single time point t. For exponential decay, the relative rate of change f / y is always the same (-1), because f = -y.

Edit: this example does not have an analytical solution, but at least in practice, I usually only care about rates of change that are negligible according to some defined threshold.
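A quick numerical check of the exponential-decay point, evaluated along the trajectory y(t) = y0 e^{-kt} with illustrative values k = 1 and y0 = 5. The ratio f / y stays at -k for all t, so a criterion based only on that ratio at a single time point never triggers, while |f| decays towards zero and eventually drops below any fixed threshold.

```python
import numpy as np

k, y0 = 1.0, 5.0
ts = np.array([0.0, 1.0, 5.0, 10.0])
y = y0 * np.exp(-k * ts)  # trajectory of dy/dt = -k * y
f = -k * y                # vector field along that trajectory

print(f / y)      # -1.0 at every time point: the ratio never changes
print(np.abs(f))  # the absolute rate of change decays towards zero
```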