patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Inconsistency between constant and adaptive step size solvers with Discrete terminating events #416

Status: Open. ivandariojr opened this issue 5 months ago.

ivandariojr commented 5 months ago

When I use an adaptive step size solver to integrate an ODE with a discrete terminating event, the detected event time is significantly later than the one computed with a constant step size solver. The discrepancy matters for my application and is larger than any of the tolerances I set. Is this expected behavior?

I am attaching a plot to illustrate the problem. You can see that the adaptive step size solver keeps integrating long after the event should have triggered (the true event time was computed analytically).


And here is a minimal working example:


```python
import diffrax as de
import jax.numpy as jnp

# Pendulum parameters
n = 2  # state dimension (unused below)
m = 1  # control dimension (unused below)
mass = 1.0
l = 1.0
gravity = 9.81
damping = 0.0
cost_threshold = 1e0  # terminate once the accumulated cost exceeds this

def f(x):
    # Drift term; x = [theta, theta_dot]
    theta, theta_dot = x[..., 0], x[..., 1]
    damp = (-theta_dot * damping) / (mass * l ** 2)
    f_1 = theta_dot
    f_2 = damp + (gravity/l)*jnp.sin(theta)
    return jnp.stack([f_1, f_2], axis=-1)

def g(x):
    # Control matrix: the input only enters the theta_dot equation
    return jnp.array([[0.0], [1.0/(mass * l**2)]])

def cost(x):
    return jnp.sum(x**2)

def vf(t, y_rc, args):
    # Augmented state: pendulum state y and accumulated (running) cost rc
    y, rc = y_rc
    y_dot = f(y) + g(y) @ jnp.array([10.0])  # constant control input of 10
    rc_dot = cost(y)
    return y_dot, rc_dot

def terminating_event(state, **kwargs):
    # Stop once the accumulated cost exceeds the threshold
    y_rc = state.y
    y, rc = y_rc
    return rc > cost_threshold

ts = jnp.linspace(0, 10, 2**13)
adaptive_solution = de.diffeqsolve(
    de.ODETerm(vf),
    solver=de.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1]-ts[0],
    saveat=de.SaveAt(ts=ts),
    y0=(jnp.array([-jnp.pi, 0.0]), jnp.array(0.0)),
    stepsize_controller=de.PIDController(rtol=1e-6, atol=1e-12, dtmin=1e-5),
    discrete_terminating_event=de.DiscreteTerminatingEvent(terminating_event),
    max_steps=8192
)

constant_solution = de.diffeqsolve(
    de.ODETerm(vf),
    solver=de.Tsit5(),
    t0=ts[0],
    t1=ts[-1],
    dt0=ts[1]-ts[0],
    saveat=de.SaveAt(ts=ts),
    y0=(jnp.array([-jnp.pi, 0.0]), jnp.array(0.0)),
    stepsize_controller=de.ConstantStepSize(),
    discrete_terminating_event=de.DiscreteTerminatingEvent(terminating_event),
    max_steps=8192
)

# After termination the remaining saved values are inf, so the entries that are
# finite for exactly one of the two solves lie between the two stopping times.
finite_xor = jnp.isfinite(constant_solution.ys[1]) ^ jnp.isfinite(adaptive_solution.ys[1])
adaptive_fail_ts = finite_xor * ts
non_zeros = jnp.nonzero(adaptive_fail_ts)[0]
fail_ts = adaptive_fail_ts[non_zeros]
first_event = ts[non_zeros[0]-1]   # last saved time before the earlier solve stopped
second_event = ts[non_zeros[-1]]   # last saved time before the later solve stopped
print(f'First event: {first_event}')
print(f'Second event: {second_event}')
print(f"Event time difference: {second_event - first_event}")
```

P.S.: Thanks so much for your work on diffrax. It has been invaluable for my research.

patrick-kidger commented 5 months ago

Hey there! I'm glad you're enjoying Diffrax :)

This is expected behaviour -- the "discrete" in "discrete terminating event" means that it does not attempt to find the root of the event function. Rather, it simply checks the event condition after each numerical step and stops integrating at the end of the first step on which the condition holds, so the reported stop time can overshoot the true event time by up to one step length. This approach is useful when you don't care precisely when something occurred (e.g. when integrating to a steady state) and don't need to spend extra computation resolving the exact time.
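To make the "discrete" semantics concrete, here is a minimal pure-Python sketch. It is not Diffrax's implementation: the hypothetical helper `integrate_with_discrete_event` takes fixed explicit Euler steps (both the solver and the fixed step size are assumptions chosen for brevity) and evaluates the event condition only after each completed step, so the stop time it reports is the end of the first step on which the condition holds.

```python
def integrate_with_discrete_event(f, y0, t0, t1, dt, cond):
    """Hypothetical helper: explicit Euler with a discrete terminating event.

    The condition is evaluated only at step endpoints, so the returned time is
    the end of the first step on which cond(t, y) is True (or t1 if it never
    triggers). No root finding is done inside the triggering step.
    """
    t, y = t0, y0
    while t < t1:
        h = min(dt, t1 - t)   # never step past t1
        y = y + h * f(t, y)   # one explicit Euler step
        t = t + h
        if cond(t, y):        # checked only once the step has completed
            return t, y
    return t, y


# Toy problem: y' = 1, y(0) = 0, event "y > 0.5"; the exact event time is 0.5.
f = lambda t, y: 1.0
cond = lambda t, y: y > 0.5

for dt in (0.3, 0.01):
    t_stop, _ = integrate_with_discrete_event(f, 0.0, 0.0, 10.0, dt, cond)
    print(f"dt={dt}: stopped at t={t_stop:.4f}, overshoot={t_stop - 0.5:.4f}")
```

With `dt=0.3` the run stops at t = 0.6, an overshoot of 0.1, while with `dt=0.01` the overshoot is at most one step. An adaptive controller changes only how long each step is, which is why a solver taking large steps in smooth regions can report an event much later than one forced to take the tiny constant step `ts[1] - ts[0]`.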

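In this case, the fact that the constant step size solver got much closer is just a fluke: every one of its steps is tiny (10/8191, roughly 1.2e-3), so the step on which the event triggers happens to end close to the true event time, whereas the adaptive controller is free to take much larger steps.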

If you do need the exact location, we intend to add this as a feature shortly. Check out the feature branch #387 and give it a go -- we'd love to get some feedback on it!
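For contrast, here is a minimal sketch of what locating the event exactly involves, assuming a continuous event function g(t, y) whose zero marks the event (for the example above, something like `rc - cost_threshold`). It is not the API or algorithm of the #387 feature branch, only a generic illustration: the hypothetical helper `locate_event` takes fixed RK4 steps, and once g becomes non-negative at the end of a step, it bisects on the sub-step length, re-taking a shorter step from the start of that step until the crossing is bracketed to within `tol`. It assumes g is negative at the initial condition.

```python
import math


def rk4_step(f, t, y, h):
    """One classical Runge-Kutta (RK4) step of size h from (t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def locate_event(f, y0, t0, t1, dt, g, tol=1e-12):
    """Hypothetical helper: step until g changes sign, then bisect within that step.

    Assumes g(t0, y0) < 0. Returns (t_event, y_event), or None if g stays
    negative up to t1.
    """
    t, y = t0, y0
    while t < t1:
        h_full = min(dt, t1 - t)
        y_next = rk4_step(f, t, y, h_full)
        if g(t + h_full, y_next) >= 0:  # the zero crossing lies inside this step
            lo, hi = 0.0, h_full        # invariant: g < 0 at lo, g >= 0 at hi
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                # Re-take a shorter step from the start of the bracketing step.
                if g(t + mid, rk4_step(f, t, y, mid)) >= 0:
                    hi = mid
                else:
                    lo = mid
            return t + hi, rk4_step(f, t, y, hi)
        t, y = t + h_full, y_next
    return None


# Toy problem: y' = y, y(0) = 1, event at y = 2; the exact event time is ln 2.
t_event, y_event = locate_event(lambda t, y: y, 1.0, 0.0, 10.0, 0.25,
                                lambda t, y: y - 2.0)
print(t_event, math.log(2))
```

On this toy problem the located time agrees with ln 2 (about 0.6931) up to the solver's own error, whereas a discrete check with the same 0.25 steps would only stop at the end of the triggering step, t = 0.75.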