(Open) cchapmanbird opened this issue 1 year ago
So this is the reason why the currently supported events are referred to as "discrete" events.
"Continuous" events are a planned feature. If this is a priority for you, then it's one I'd be happy to accept a PR on. Probably the best thing to do would be to rename `DiscreteTerminatingEvent` -> `TerminatingEvent`
and have it accept an optional nonlinear solver argument; if one is passed, it would try to find the exact termination time.
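To make this concrete, here is a minimal sketch of the kind of root-finding step such a solver argument could perform once a step has crossed the event. It is written in plain JAX and is not an existing diffrax API. It assumes a continuous interpolant `interp(t)` of the solution and a continuous event function `g(y)` whose sign changes inside the bracket `[t_lo, t_hi]`. It then refines the crossing time with a safeguarded Newton iteration that falls back to bisection. `refine_event_time`, `interp` and `g` are hypothetical names. With diffrax, `interp` might be something like `sol.evaluate` from a dense solve; the usage below instead uses the exact solution exp(t) of dy/dt = y as a stand-in.

```python
import jax
import jax.numpy as jnp


def refine_event_time(interp, g, t_lo, t_hi, num_iters=20):
    """Find t in [t_lo, t_hi] with g(interp(t)) == 0 (hypothetical helper).

    Assumes g(interp(t_lo)) < 0 <= g(interp(t_hi)), i.e. the continuous event
    function changes sign inside the bracket.
    """
    h = lambda t: g(interp(t))
    dh = jax.grad(h)

    def step(_, carry):
        lo, hi, t = carry
        val = h(t)
        # Shrink the bracket using the sign of the event function at t.
        lo = jnp.where(val < 0, t, lo)
        hi = jnp.where(val < 0, hi, t)
        # Newton proposal; fall back to bisection if it leaves the bracket
        # (this also catches a zero derivative, which yields inf or nan).
        t_newton = t - val / dh(t)
        inside = (t_newton >= lo) & (t_newton <= hi)
        t = jnp.where(inside, t_newton, 0.5 * (lo + hi))
        return lo, hi, t

    lo = jnp.asarray(t_lo, dtype=float)
    hi = jnp.asarray(t_hi, dtype=float)
    _, _, t = jax.lax.fori_loop(0, num_iters, step, (lo, hi, 0.5 * (lo + hi)))
    return t


# Stand-in interpolant: exact solution of dy/dt = y with y(0) = 1.
interp = lambda t: jnp.exp(t) * jnp.ones(1)
t_event = refine_event_time(interp, lambda y: y[0] - 2.0, 0.69, 0.83)
print(t_event)  # close to log(2) = 0.693147...
```

Using a continuous `g` rather than a boolean condition is what makes Newton applicable. The bracket keeps the iteration from wandering outside the step in which the event was detected.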
Hi Patrick,
I really like diffrax and am testing it extensively these days. I have also encountered the same issue: the state at which the event happens is not exact. Previously I used `odeint_event` in torchdiffeq, where the event state is exact and the derivative with respect to the input parameters is also exact (using methods from this paper). As I am gradually switching to JAX, I am wondering whether functionality similar to `odeint_event` in torchdiffeq is planned for diffrax?
Yes, it's a future plan.
Thank you for your great library. I am very interested in this feature. Do you plan to implement it in the near future?
I implemented a very simple version that finds the boundary with a binary search over the interpolation. I am very new to JAX and diffrax and I know it's not the best approach, but in case it helps someone:
```python
import jax
import jax.numpy as jnp
import diffrax
from jax.experimental import io_callback

# The tolerances used below only make sense in double precision.
jax.config.update("jax_enable_x64", True)


def f(t, x, args):
    # Test problem dy/dt = y, so y(t) = exp(t) and y reaches 2 at t = ln(2).
    y = x
    return y


terms = diffrax.ODETerm(f)
t0 = 0.0
t1 = 1000
y0 = jnp.array([1.0])
dt0 = 0.0002
saveat = diffrax.SaveAt(dense=True, t1=True)
solver = diffrax.Kvaerno5()
stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)


class TerminatingEvent:
    def __init__(self, fun):
        self._fun = fun  # maps a state y to a boolean "event has happened"
        self.last_t_before_threshold = 0.0
        self.last_t_after_threshold = 0.0

    def set_last_t_before_threshold(self, last_t_before_threshold):
        self.last_t_before_threshold = last_t_before_threshold

    def set_last_t_after_threshold(self, last_t_after_threshold):
        self.last_t_after_threshold = last_t_after_threshold

    def fun(self, state, **kwargs):
        # Called by diffrax on each iteration of the solve loop. The callbacks
        # record, on the Python side, the last time before the threshold and
        # the time at which the event first reports True.
        def false_fn():
            io_callback(self.set_last_t_before_threshold, None, state.tprev)
            return False

        def true_fn():
            io_callback(self.set_last_t_after_threshold, None, state.tprev)
            return True

        return jax.lax.cond(self._fun(state.y), true_fn, false_fn)

    def search_optimal_t(self, interpolation):
        # Bisection on the dense interpolation inside the recorded bracket,
        # stopping once no representable midpoint lies between low and high.
        def cond(state):
            low, high = state
            midpoint = 0.5 * (low + high)
            return (low < midpoint) & (midpoint < high)

        def body(state):
            low, high = state
            midpoint = 0.5 * (low + high)
            result = self._fun(interpolation.evaluate(midpoint))
            low = jnp.where(result, low, midpoint)
            high = jnp.where(result, midpoint, high)
            return (low, high)

        low_result, high_result = jax.lax.while_loop(
            cond, body, (self.last_t_before_threshold, self.last_t_after_threshold)
        )
        return (low_result + high_result) / 2.0


terminating_event = TerminatingEvent(lambda y: y[0] > 2.0)
sol = diffrax.diffeqsolve(
    terms,
    solver,
    t0,
    t1,
    dt0,
    y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    discrete_terminating_event=diffrax.DiscreteTerminatingEvent(terminating_event.fun),
    max_steps=100000,
)
print("Times interval", terminating_event.last_t_before_threshold, terminating_event.last_t_after_threshold)
result = terminating_event.search_optimal_t(sol.interpolation)
print("computed result", result)
print("expected result", jnp.log(2.0))
print("error", jnp.abs(result - jnp.log(2.0)))
# Output:
# Times interval 0.6906340910469938 0.8243209092563926
# computed result 0.6931471700829934
# expected result 0.6931471805599453
# error 1.047695186162656e-08
```
One issue with this implementation is that it requires `dense=True` in `SaveAt`, which stores a lot of state. It's possible to avoid saving dense output over the whole domain. First run `diffeqsolve` with `dense=False`. Then call `diffeqsolve` again with `dense=True`, but only on the interval (`terminating_event.last_t_before_threshold`, `terminating_event.last_t_after_threshold`). The second solve starts from the state at `last_t_before_threshold`, so the callback would need to record that state as well.
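A lighter-weight alternative to that second solve is sketched below. It is a swapped-in technique, not the one described above. The idea is to keep only the two states that bracket the event, together with their derivatives, and bisect on a cubic Hermite interpolant over that single step. The sketch assumes the callbacks are extended to record `y` and `f(t, y)` at both bracket times. `hermite_interp` and `locate_event` are hypothetical helpers, and the result is only as accurate as a cubic interpolant over the final step.

```python
import jax
import jax.numpy as jnp


def hermite_interp(t, t0, t1, y0, y1, f0, f1):
    """Cubic Hermite interpolant on [t0, t1] from endpoint values and slopes."""
    h = t1 - t0
    s = (t - t0) / h
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y0 + h10 * h * f0 + h01 * y1 + h11 * h * f1


def locate_event(event_fn, t0, t1, y0, y1, f0, f1, num_iters=40):
    """Bisect for the event time on the Hermite interpolant over [t0, t1].

    Assumes event_fn(y0) is False and event_fn(y1) is True.
    """
    t0 = jnp.asarray(t0, dtype=float)
    t1 = jnp.asarray(t1, dtype=float)

    def body(_, bracket):
        low, high = bracket
        mid = 0.5 * (low + high)
        hit = event_fn(hermite_interp(mid, t0, t1, y0, y1, f0, f1))
        # Keep the half of the bracket in which the event switches on.
        return jnp.where(hit, low, mid), jnp.where(hit, mid, high)

    low, high = jax.lax.fori_loop(0, num_iters, body, (t0, t1))
    return 0.5 * (low + high)


# Bracket close to the printed one, for dy/dt = y (so f(t, y) = y = exp(t)).
t_b, t_a = 0.6906, 0.8243
y_b, y_a = jnp.exp(jnp.array([t_b])), jnp.exp(jnp.array([t_a]))
t_event = locate_event(lambda y: y[0] > 2.0, t_b, t_a, y_b, y_a, y_b, y_a)
print(t_event)  # close to log(2)
```

This only needs four stored arrays per event rather than dense output over the whole solve.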
It looks like you have a solution that works well for your case. I have a similar situation and found a different solution, which I am posting here in case someone is interested.
Basically, I have a vector field of the form d/dt [x, v] = [xdot, vdot], where vdot happens to always be positive (x is a vector and v a scalar). The goal is to integrate up to exactly when v hits some vmax, regardless of the value of t there. To do this I perform a change of variables, first appending the trivial d/dt t = 1 to the original dynamics just to keep track of everything. Then I divide the vector field by vdot so that the new independent variable is v (by the chain rule, dx/dv = (dx/dt) / (dv/dt) = xdot / vdot), giving us:
d/dv [x, v, t] = [xdot / vdot, vdot / vdot = 1, 1 / vdot]
Obviously, in this new form dv/dv = 1, so v can be treated as the new independent variable (confusingly still named `ts` by diffrax). So I remove v from the state and solve for [x, t] as a function of v. This makes it trivial to specify vmax as the integration boundary, and we know its exact final value without any kind of nonlinear solving! (Granted, this comes with the additional requirement that the variable associated with the termination event is monotonic.) Also, be careful that the change of variables does not degrade the solution accuracy; this is a risk especially if vdot is small at some points or changes quickly. In my case, though, the standard adaptive step-size control works wonderfully.
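A minimal sketch of this change of variables follows. It uses plain JAX with a fixed-step RK4 loop so that it runs on its own. In practice one would presumably wrap `rescaled_field` in a function with diffrax's `(t, y, args)` signature and pass it to `diffrax.diffeqsolve` with `t0=v0` and `t1=vmax`. The toy dynamics xdot = v, vdot = v are a hypothetical choice, picked because the answer is known in closed form: x(vmax) = x0 + (vmax - v0) and t(vmax) = ln(vmax / v0).

```python
import jax
import jax.numpy as jnp


def vector_field(t, x, v):
    """Original dynamics d/dt [x, v] = [xdot, vdot] (hypothetical toy system).

    xdot = v and vdot = v, so vdot > 0 for v > 0; then v(t) = v0 * exp(t) and
    x(t) = x0 + v0 * (exp(t) - 1).
    """
    xdot = jnp.atleast_1d(v)
    vdot = v
    return xdot, vdot


def rescaled_field(v, state):
    """d/dv [x, t] = [xdot / vdot, 1 / vdot], with v as the independent variable."""
    x, t = state[:-1], state[-1]
    xdot, vdot = vector_field(t, x, v)
    return jnp.concatenate([xdot / vdot, jnp.atleast_1d(1.0 / vdot)])


def integrate_to_vmax(state0, v0, vmax, num_steps=100):
    """Fixed-step RK4 in v from v0 to exactly vmax (stand-in for diffeqsolve)."""
    h = (vmax - v0) / num_steps

    def rk4_step(i, state):
        v = v0 + i * h
        k1 = rescaled_field(v, state)
        k2 = rescaled_field(v + h / 2, state + h / 2 * k1)
        k3 = rescaled_field(v + h / 2, state + h / 2 * k2)
        k4 = rescaled_field(v + h, state + h * k3)
        return state + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, num_steps, rk4_step, state0)


# State is [x, t]; start at x = 0, t = 0 with v0 = 1 and stop at vmax = 2.
x_end, t_end = integrate_to_vmax(jnp.array([0.0, 0.0]), 1.0, 2.0)
print(x_end, t_end)  # expect x close to 1 and t close to log(2)
```

Because the integration interval is [v0, vmax], the final point lands on the event surface v = vmax by construction. The event time is simply read off the last component of the state.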
A common feature of ODE solvers with event handling (e.g. scipy, DifferentialEquations.jl) is a root-finding step that ensures the final step of the integration lands on (or at least within tolerance of) the surface defined by the event. In Diffrax, however, the integration is simply cut off after crossing this boundary, leaving the final point wherever the last step happened to end.
Is this a planned feature? It would be very useful, especially in cases such as mine where having this finishing point finely resolved in time is very important.
edit: adding a plot to illustrate what I mean