patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Perform root-finding to tune the final step of an integration with a `DiscreteTerminatingEvent` #227

Open cchapmanbird opened 1 year ago

cchapmanbird commented 1 year ago

A common feature of ODE solvers that support event handling (e.g. scipy, DifferentialEquations.jl) is a root-finding step that places the final step of the integration on the surface defined by the event (or at least within tolerance of it). In Diffrax, however, the integration is simply cut off at the first step past this boundary, so the final point lands wherever that step happened to end.

Is this a planned feature? It would be very useful, especially in cases such as mine where having this finishing point finely resolved in time is very important.

edit: adding a plot to illustrate what I mean [image]

patrick-kidger commented 1 year ago

So this is the reason why the currently-supported events are referred to as "discrete" events.

"Continuous" events are a planned feature. If this is a priority for you then it's one that I'd be happy to accept a PR on. Probably the best thing to do would be to rename DiscreteTerminatingEvent -> TerminatingEvent and have it accept an optional nonlinear solver argument; if passed, it would try to find the exact termination time.

liuyangdh commented 1 year ago

Hi Patrick, I really like diffrax and have been testing it extensively these days. I ran into the same issue: the state at the moment the event happens is not exact. Previously I used odeint_event in torchdiffeq, where the event state is exact and the derivative with respect to the input parameters is also exact (using methods from this paper). As I am gradually switching to JAX, I am wondering whether functionality similar to torchdiffeq's odeint_event is planned for diffrax?

patrick-kidger commented 1 year ago

Yes, it's a future plan.

dv-ai commented 1 year ago

Thank you for your great library. I am very interested in this feature. Do you plan to implement it in the near future?

I implemented a very simple version that finds the boundary with a binary search over the dense interpolation. I am very new to JAX and diffrax and I know it's not the best approach, but in case it helps someone:

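    import diffrax
    import jax
    import jax.experimental  # provides io_callback
    import jax.numpy as jnp

    # Assumption: the double-precision output below suggests x64 was enabled.
    jax.config.update("jax_enable_x64", True)

    def f(t, x, args):
        # dx/dt = x, so x(t) = exp(t) and x[0] crosses 2 at t = log(2)
        return x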

    terms = diffrax.ODETerm(f)
    t0 = 0.0
    t1 = 1000
    y0 = jnp.array([1.0])
    dt0 = 0.0002
    saveat = diffrax.SaveAt(dense=True, t1=True)

    solver = diffrax.Kvaerno5()
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)

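    # Wraps a boolean condition and records, via host callbacks, the last times
    # at which it evaluated to False and to True. These bracket the event.
    class TerminatingEvent: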

        def __init__(self, fun):
            self._fun = fun
            self.last_t_before_threshold = 0.0
            self.last_t_after_threshold = 0.0

        def set_last_t_before_threshold(self, last_t_before_threshold):
            self.last_t_before_threshold = last_t_before_threshold

        def set_last_t_after_threshold(self, last_t_after_threshold):
            self.last_t_after_threshold = last_t_after_threshold

        def fun(self, state, **kwargs):
            def false_fn():
                jax.experimental.io_callback(self.set_last_t_before_threshold, None, state.tprev)
                return False

            def true_fn():
                jax.experimental.io_callback(self.set_last_t_after_threshold, None, state.tprev)
                return True

            return jax.lax.cond(self._fun(state.y), true_fn, false_fn)

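        def search_optimal_t(self, interpolation):
            # Bisect the recorded bracket on the dense interpolant until the
            # midpoint can no longer be distinguished from an endpoint.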
            def cond(state):
                low, high = state
                midpoint = 0.5 * (low + high)
                return (low < midpoint) & (midpoint < high)

            def body(state):
                low, high = state
                midpoint = 0.5 * (low + high)
                state = interpolation.evaluate(midpoint)
                result = self._fun(state)
                low = jnp.where(result, low, midpoint)
                high = jnp.where(result, midpoint, high)
                return (low, high)

            init = (self.last_t_before_threshold, self.last_t_after_threshold)
            low_result, high_result = jax.lax.while_loop(cond, body, init)
            return (low_result + high_result) / 2.0

    terminating_event = TerminatingEvent(lambda y: y[0] > 2.0)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        discrete_terminating_event=diffrax.DiscreteTerminatingEvent(terminating_event.fun),
        max_steps=100000
    )
    print("Times interval", terminating_event.last_t_before_threshold, terminating_event.last_t_after_threshold)
    result = terminating_event.search_optimal_t(sol.interpolation)
    print("computed result", result)
    print("expected result", jnp.log(2.0))
    print("error", jnp.abs(result - jnp.log(2.0)))

    ## Times interval 0.6906340910469938 0.8243209092563926
    ## computed result 0.6931471700829934
    ## expected result 0.6931471805599453
    ## error 1.047695186162656e-08

One issue with this implementation is that it requires dense=True in SaveAt, which stores a lot of state. It is possible to avoid saving the whole domain by first running diffeqsolve with dense=False, and then calling diffeqsolve again with dense=True on just the interval (terminating_event.last_t_before_threshold, terminating_event.last_t_after_threshold).
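
The two-pass idea can be sketched in plain Python, without diffrax, on the same toy problem (dy/dt = y, event y > 2). This is only an illustration under stated assumptions: a fixed-step RK4 stands in for the solver, a cubic Hermite interpolant stands in for the dense output, and the helper names (`find_bracket`, `dense_on_bracket`, `evaluate`, `bisect`) are hypothetical. Note that the second pass has to restart from the state at the left end of the bracket, so the first pass keeps that one state.

```python
import math

def f(t, y):
    return y  # toy problem: dy/dt = y, y(0) = 1, so y(t) = exp(t)

def event(y):
    return y > 2.0  # boolean event, as in the discrete-event setting

def rk4_step(t, y, h):
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def find_bracket(t, y, h, max_steps=100_000):
    """Pass 1: store no history; return (t_lo, y_lo, t_hi) around the event."""
    for _ in range(max_steps):
        y_next = rk4_step(t, y, h)
        if event(y_next):
            return t, y, t + h
        t, y = t + h, y_next
    raise RuntimeError("event never triggered")

def dense_on_bracket(t_lo, y_lo, t_hi, n_sub):
    """Pass 2: re-integrate only the bracket, storing (t, y, dy/dt) nodes."""
    h = (t_hi - t_lo) / n_sub
    nodes = [(t_lo, y_lo, f(t_lo, y_lo))]
    y = y_lo
    for i in range(n_sub):
        y = rk4_step(t_lo + i * h, y, h)
        t = t_lo + (i + 1) * h
        nodes.append((t, y, f(t, y)))
    return nodes

def evaluate(nodes, t):
    """Cubic Hermite interpolation between the stored nodes."""
    for (ta, ya, da), (tb, yb, db) in zip(nodes, nodes[1:]):
        if t <= tb:
            break
    h = tb - ta
    s = (t - ta) / h
    return ((2 * s**3 - 3 * s**2 + 1) * ya + (s**3 - 2 * s**2 + s) * h * da
            + (-2 * s**3 + 3 * s**2) * yb + (s**3 - s**2) * h * db)

def bisect(nodes, lo, hi, tol=1e-12):
    """Shrink [lo, hi], keeping event False at lo and True at hi."""
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if event(evaluate(nodes, mid)):
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

t_lo, y_lo, t_hi = find_bracket(0.0, 1.0, h=0.1)
nodes = dense_on_bracket(t_lo, y_lo, t_hi, n_sub=20)
t_event = bisect(nodes, t_lo, t_hi)
print("bracket", (t_lo, t_hi), "event time", t_event, "expected", math.log(2.0))
```

Only the 21 nodes inside the bracket are stored, regardless of how long the first pass ran. The accuracy of the event time is limited by the accuracy of the first pass, since the second pass starts from its state.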

mbjd commented 12 months ago

It looks like you have a solution that works well for your case. I have a similar situation, found a different solution and am posting it here just in case someone is interested.

Basically, I have a vector field of the form d/dt [x, v] = [xdot, vdot], where vdot happens to always be positive (x is a vector and v a scalar). The goal is to integrate up to exactly when v hits some vmax, regardless of the value of t there. To do this I perform a change of variables, appending the trivial d/dt t = 1 to the original dynamics just to keep track of everything. Then I can just scale the vector field such that the new independent variable is v, giving us:

d/dv [x, v, t] = [xdot / vdot, vdot / vdot = 1, 1 / vdot]

Obviously, in this new form the derivative of v is identically 1, so we may treat v as our new independent variable (confusingly still named ts by diffrax). So I remove v from the state and solve for [x, t] as a function of v. This makes it trivial to specify vmax as the integration boundary, and we know its exact final value without any kind of nonlinear solve! (Granted, this comes with the additional requirement that the variable tied to the termination event is monotonic.) Also, be careful that the change of variables does not degrade the solution accuracy: this is a risk especially if vdot is small at some points or changes quickly. In my case, though, the standard adaptive step size control works wonderfully.
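
A minimal sketch of this change of variables, in plain Python rather than diffrax, might look as follows. The toy vector field (xdot = -x, vdot = v, chosen so that vdot stays positive and the exact answer is known) and the helper names `rk4_integrate`, `original_field` and `rescaled_rhs` are assumptions made for illustration only. A fixed-step RK4 stands in for the solver.

```python
import math

def rk4_integrate(rhs, s0, s1, state, n_steps):
    """Fixed-step classical RK4 for d(state)/ds = rhs(s, state) on [s0, s1]."""
    h = (s1 - s0) / n_steps
    for i in range(n_steps):
        s = s0 + i * h
        k1 = rhs(s, state)
        k2 = rhs(s + h / 2, [y + h / 2 * k for y, k in zip(state, k1)])
        k3 = rhs(s + h / 2, [y + h / 2 * k for y, k in zip(state, k2)])
        k4 = rhs(s + h, [y + h * k for y, k in zip(state, k3)])
        state = [y + h / 6 * (a + 2 * b + 2 * c + d)
                 for y, a, b, c, d in zip(state, k1, k2, k3, k4)]
    return state

def original_field(x, v):
    """Toy dynamics in time: returns (xdot, vdot), with vdot > 0 for v > 0."""
    return -x, v

def rescaled_rhs(v, state):
    """Dynamics with v as the independent variable: d/dv [x, t]."""
    x, t = state
    xdot, vdot = original_field(x, v)
    return [xdot / vdot, 1.0 / vdot]

V0, VMAX = 1.0, 2.0
# Initial state [x, t] at v = V0; integrate exactly up to v = VMAX.
x_final, t_final = rk4_integrate(rescaled_rhs, V0, VMAX, [1.0, 0.0], 200)
print("t at vmax:", t_final, "expected:", math.log(2.0))
print("x at vmax:", x_final, "expected:", 0.5)
```

For this toy problem v(t) = exp(t) and x(t) = exp(-t), so the integration should end at t = log(2) and x = 0.5. The end point sits on v = vmax by construction, with no root-finding involved.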