patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Event and PIDController: event doesn't always occur #507

Open dv-ai opened 1 week ago

dv-ai commented 1 week ago

First, I want to thank you for your amazing library. You have done a massive amount of work, and it is very useful for my research.


diffrax 0.6.0 optimistix 0.0.7 jax 0.4.30

When the PID controller and the Event functionality are used together, I found that the event is not always raised, because the event function and the ODE vector field can differ in complexity. For example, if the ODE is very simple (a straight line), the PID controller allows very large steps. A large step can miss the event if the event condition changes sign twice between the start and the end of that step.
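To make the failure mode concrete, the sketch below assumes the detector only compares the sign of the event function at the two endpoints of each accepted step. `sign_change_between` is a hypothetical helper in plain Python (no Diffrax), using the same cubic as the example below with dy/dt = 1:

```python
def cond(y):
    # Event function from the example: roots at y = 0, 2.5 and 5.
    return (2.5 - y) * (5.0 - y) * (0.0 - y)


def sign_change_between(y_start, y_end):
    """Hypothetical endpoint-only test: True iff cond changes sign across a step."""
    return cond(y_start) * cond(y_end) < 0


# With dy/dt = 1, a single step from y = 0.5 to y = 6.0 crosses both y = 2.5 and
# y = 5.0. The sign flips twice, so both endpoints are negative and nothing fires.
print(cond(0.5), cond(6.0))           # -4.5 -21.0
print(sign_change_between(0.5, 6.0))  # False: event missed
print(sign_change_between(0.5, 3.0))  # True: a shorter step catches the first root
```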

To work around this, I found that integrating the event condition into the ODE itself fixes the issue in my particular use case. Is there a more mathematically grounded way to resolve this?

A small Python example:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

import diffrax
import optimistix as optx


def event(x, high_pres=False, coeff=1.0):
    x = jnp.concatenate([jnp.ones((1, 1)) * x, jnp.zeros((1, 1))], axis=1)

    # event_condition = (2.5 - y) * (5.0 - y) * (0.0 - y): roots at y = 0, 2.5 and 5
    event_condition = lambda t, y, args, **kwargs: (-y[0, 0] + 2.5) * (-y[0, 0] + 5.0) * (-y[0, 0] + 0.0)

    # dx/dt = [1, 0]
    # With the event condition added as an extra component:
    #   dx/dt = [1, 0, coeff * |event_condition(t, x)|]
    # coeff switches this extra term on (1.0) or off (0.0).
    def ode_fun(t, x_, args):
        extra = coeff * jnp.abs(event_condition(t, x_, args))
        return jnp.concatenate([jnp.ones((1, 1)), jnp.zeros((1, 1)), extra.reshape(1, 1)], axis=1)

    fun = diffrax.ODETerm(ode_fun)

    if high_pres:
        stepsize_controller = diffrax.PIDController(rtol=1E-14, atol=1E-14)
    else:
        stepsize_controller = diffrax.PIDController(rtol=1E-8, atol=1E-8)

    solver = diffrax.Dopri8()
    root_finder = optx.Bisection(1E-10, 1E-10)

    t1 = 10

    sol = diffrax.diffeqsolve(
        fun,
        solver,
        0.0,
        t1,
        None,  # dt0=None: let the PID controller pick the initial step
        jnp.concatenate([x, jnp.zeros((1, 1))], axis=1),
        stepsize_controller=stepsize_controller,
        max_steps=None,
        event=diffrax.Event(event_condition, root_finder),
        throw=False,
    )

    event_occurred = diffrax.RESULTS.event_occurred == sol.result

    t_result = sol.ts
    x_last = sol.ys[0][:, :2]

    print(event_occurred, x_last, t_result)
    print(sol.result)


event(0.5, high_pres=True, coeff=0.0)   # -> with high precision (1E-14), the event is detected
event(0.5, high_pres=False, coeff=0.0)  # -> with "low" precision (1E-8), the event is not detected
event(0.5, high_pres=False, coeff=1.0)  # -> with "low" precision (1E-8) and the event condition added to the ODE, the event is detected
```
patrick-kidger commented 1 week ago

This is a classical problem with event handling.

Some ODE solvers try to reduce the impact of this by evaluating the event function at several equally spaced points within every step (e.g. 10 per step). You can of course still hit the same issue; it just has to happen at a finer resolution. For that reason Diffrax doesn't do this.
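The sub-sampling strategy described here can be sketched as follows. `first_bracket` is a hypothetical helper in plain Python, not Diffrax code; it interpolates the state linearly across the step, which is exact only because dy/dt = 1 in the example (a real solver would evaluate its dense output instead). It catches the double crossing from the first post, but still misses two roots that fall between neighbouring sample points:

```python
def cond(y):
    # Same cubic event function: roots at y = 0, 2.5 and 5.
    return (2.5 - y) * (5.0 - y) * (0.0 - y)


def first_bracket(cond_fn, y_start, y_end, n=10):
    """Hypothetical helper: sample cond_fn at n + 1 equally spaced points across
    one step and return the first sub-interval with a sign change, else None.

    The state is interpolated linearly, which is exact here only because
    dy/dt = 1; a real solver would evaluate its dense output instead."""
    ys = [y_start + (y_end - y_start) * k / n for k in range(n + 1)]
    vals = [cond_fn(y) for y in ys]
    for k in range(n):
        if vals[k] * vals[k + 1] < 0:
            return ys[k], ys[k + 1]
    return None


# The large step from the first post: 10 samples expose the root at y = 2.5.
print(first_bracket(cond, 0.5, 6.0))  # roughly (2.15, 2.7)

# Two roots closer together than the sample spacing (0.55) are still missed.
close_roots = lambda y: (2.5 - y) * (2.55 - y)
print(first_bracket(close_roots, 0.5, 6.0))  # None
```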

I recommend picking an event function that changes sign only once in your region of integration. Unfortunately there's no smarter way to handle this in general.
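For the cubic in the example above, one way to follow this advice is to split it into its linear factors and monitor each one separately: along dy/dt = 1 each factor changes sign at most once, so an endpoint sign test cannot miss it. This is a hypothetical plain-Python sketch (no Diffrax), reusing an endpoint-only sign test:

```python
# Each linear factor of (2.5 - y) * (5.0 - y) * (0.0 - y) changes sign at most
# once while y increases, so comparing signs at step endpoints cannot miss it.
factors = {
    "y = 0.0": lambda y: 0.0 - y,
    "y = 2.5": lambda y: 2.5 - y,
    "y = 5.0": lambda y: 5.0 - y,
}


def cubic(y):
    return factors["y = 0.0"](y) * factors["y = 2.5"](y) * factors["y = 5.0"](y)


def triggered(y_start, y_end):
    """Names of the factors whose sign differs between the two step endpoints."""
    return [name for name, f in factors.items() if f(y_start) * f(y_end) < 0]


# One large step from y = 0.5 to y = 6.0, as a PID controller may take when dy/dt = 1:
print(cubic(0.5) * cubic(6.0) < 0)  # False: the product hides both crossings
print(triggered(0.5, 6.0))          # ['y = 2.5', 'y = 5.0']: each factor reports its own
```

The earliest reported crossing would then be located by root finding within that step.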

dv-ai commented 1 week ago

Thank you for your quick answer. After running into this, I realized that unless the solver does something specific about it, the event will not always be detected.

I am not familiar with the literature on event location for ODEs, but I think some authors prove detection guarantees for modern solvers, e.g. "Shampine, L. F., & Thompson, S. (2001). Event location for ordinary differential equations. Computers & Mathematics with Applications, 42(1-2), 85-93." Another reference: Reliable solution of special event location problems for ODEs.

As far as I understand, they modify the original ODE system (which was my intuition) to obtain these guarantees. The modification is quite simple: the total derivative of the event function $S$ is added to the ODE as an extra equation:
$$\frac{dy(t)}{dt} = F(y(t), t), \qquad y(0) = x$$

$$\frac{dz(t)}{dt} = \frac{\partial S(y(t), t)}{\partial t} + \nabla_y S(y(t), t) \cdot F(y(t), t), \qquad z(0) = S(x, 0)$$
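As a worked instance, take the example from the first post, with $F = (1, 0)$ acting on the first two state components, $S$ the cubic `event_condition`, and the extra $|S|$ component left out. Then $S$ does not depend on $t$ and:

$$S(y) = (2.5 - y_1)(5 - y_1)(0 - y_1) = -y_1^3 + 7.5\,y_1^2 - 12.5\,y_1$$

$$\frac{\partial S}{\partial t} = 0, \qquad \nabla_y S \cdot F = \frac{\partial S}{\partial y_1} \cdot 1 + \frac{\partial S}{\partial y_2} \cdot 0 = -3y_1^2 + 15\,y_1 - 12.5$$

$$\frac{dz(t)}{dt} = -3\,y_1(t)^2 + 15\,y_1(t) - 12.5, \qquad z(0) = S(x) = -4.5 \text{ for } x = (0.5, 0)$$

Since $y_1(t) = 0.5 + t$, this gives $z(t) = S(y(t))$, which crosses zero at $t = 2$ and $t = 4.5$. Because $z$ is now part of the integrated state, the step-size controller also has to resolve its curvature, which is what stops it from taking one step over both crossings.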

I think that if Diffrax doesn't integrate this kind of technique, it could at least be described in the documentation. When I first ran into this, I assumed that the detection error would be tied to the ODE solver's error tolerance, and that Diffrax would manage that.

patrick-kidger commented 1 week ago

Right, so this is a nice idea! It 'slows down' the integration to match the rate at which the event function varies.

Unfortunately this requires the event function to be real-valued and differentiable. Neither is necessarily true in the general case that Diffrax supports.

That said I could see us perhaps adding something like an Event(..., reliable=True) flag that would add this auxiliary equation on an opt-in basis. I'd be happy to take a PR on that. (Also tagging @cholberg for interest.)
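To make the opt-in idea more concrete, here is a minimal JAX sketch of the transformation such a flag might perform. Everything in it is an assumption: `make_reliable` is a hypothetical helper rather than Diffrax API, the state is a flat array, and `S(t, y)` is scalar-valued and differentiable. It uses `jax.jvp` along the tangent $(1, F)$ to compute $\partial S/\partial t + \nabla_y S \cdot F$ in one pass, and is checked with a hand-written fixed-step RK4 loop rather than `diffeqsolve`:

```python
import jax
import jax.numpy as jnp


def make_reliable(F, S):
    """Hypothetical helper, not part of Diffrax. Given dy/dt = F(t, y) and a
    scalar, differentiable event function S(t, y), build the augmented system
    dz/dt = dS/dt + grad_y S . F with z(t0) = S(t0, y0)."""

    def augmented_vf(t, state):
        y = state[:-1]
        t = jnp.asarray(t, dtype=state.dtype)
        dy = F(t, y)
        # JVP along the tangent (1, F) gives the total derivative of S.
        _, dz = jax.jvp(S, (t, y), (jnp.ones_like(t), dy))
        return jnp.concatenate([dy, jnp.reshape(dz, (1,))])

    def augmented_y0(t0, y0):
        return jnp.concatenate([y0, jnp.reshape(S(t0, y0), (1,))])

    def augmented_cond(t, state):
        # Watch z rather than S: z is part of the integrated state, so an
        # adaptive step-size controller would also control its error.
        return state[-1]

    return augmented_vf, augmented_y0, augmented_cond


def F(t, y):
    # dy/dt = [1, 0], as in the example from the issue
    return jnp.array([1.0, 0.0])


def S(t, y):
    # Same cubic event function: roots at y[0] = 0, 2.5 and 5
    return (2.5 - y[0]) * (5.0 - y[0]) * (0.0 - y[0])


vf, y0_fn, cond = make_reliable(F, S)


@jax.jit
def rk4_step(t, s, dt):
    # Classical fixed-step RK4, used here only to check the augmented system.
    k1 = vf(t, s)
    k2 = vf(t + dt / 2, s + dt / 2 * k1)
    k3 = vf(t + dt / 2, s + dt / 2 * k2)
    k4 = vf(t + dt, s + dt * k3)
    return t + dt, s + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


t, state = jnp.asarray(0.0), y0_fn(0.0, jnp.array([0.5, 0.0]))
for _ in range(100):
    t, state = rk4_step(t, state, 0.1)
print(cond(t, state), S(t, state[:-1]))  # z(10) and S(10, y(10)), both close to -462
```

In Diffrax itself the augmented field would presumably be wrapped in `diffrax.ODETerm` and `augmented_cond` passed as the event function; that wiring is not shown here.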