patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Adaptive solver ignores `step_ts` for impulse discontinuous ODE #211

Open KeAWang opened 1 year ago

KeAWang commented 1 year ago

I'm trying to integrate a jump-discontinuous ODE (i.e. one that is discontinuous at a specific time point) using an adaptive solver.

However, the integration seems to ignore my `step_ts`:

```python
import diffrax

t0, t1 = 0, 3
jump = 0.5  # let's jump at this point
step_ts = [t0, jump, t1]

def vector_field(t, y, args):
    return -1.0 * y + 100.0 * (t == jump)  # also tried jnp.allclose but doesn't work either

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
saveat = diffrax.SaveAt(ts=step_ts)
stepsize_controller = diffrax.PIDController(
    pcoeff=0.0,
    icoeff=1.0,
    dcoeff=0.0,
    rtol=1e-4,
    atol=1e-6,
    step_ts=step_ts,
)

sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=t0,
    t1=t1,
    dt0=None,
    y0=1,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

print(sol.ts)
assert jump in sol.ts
print(sol.ys)
```

Output:

```
[0.  0.5 3. ]
[1.         0.6068427  0.04981759]  # should be something much larger since we have a large jump at t=0.5
```

I'm avoiding `jump_ts` based on https://github.com/patrick-kidger/diffrax/issues/58, which addressed the case of piecewise-constant vector fields. However, I'm not quite sure what to do when there's a delta function in the vector field (an impulsive external control).

Is there a way to do this in diffrax?

jaschau commented 1 year ago

Hi, I think what you are trying to do here is not really well-defined. A delta distribution is not something you'll be able to handle directly in any numerical code. Of course, you can do the standard engineering trick and integrate the differential equation over an $\epsilon$-neighbourhood of the jump time $t_0$, which gives

$$\int_{t_0 - \epsilon}^{t_0 + \epsilon} \dot y \, dt = y(t_0 + \epsilon) - y(t_0 - \epsilon) = - \int_{t_0 - \epsilon}^{t_0 + \epsilon} y \, dt + 100 \int_{t_0 - \epsilon}^{t_0 + \epsilon} \delta(t - t_0) \, dt \xrightarrow{\epsilon \rightarrow 0} 100$$

(the integral of $y$ vanishes in the limit because $y$ stays bounded). This tells you that $y$ jumps by 100 on passing $t_0$. So what you could do is integrate the differential equation numerically without the delta distribution up to $t_0$, and then continue the integration from $t_0$ onwards with the initial condition given by whatever the first integration produced, plus 100.

patrick-kidger commented 1 year ago

As @jaschau suggests, probably the simplest approach is to split your problem into two separate `diffeqsolve`s and apply your desired impulse between the two solves. (Actually, a `lax.scan` over `diffeqsolve`s would be slightly better, as that'll reduce compilation time.) I'm afraid it's expected that using `step_ts` wouldn't result in the delta function being noticed; what you're trying to do isn't within the normal remit of differential equation solvers.

Other notes:

[Also, wait, we can use LaTeX on GitHub now? Hurrah!]

KeAWang commented 1 year ago

That makes sense! Thank you to both of you :) Would be great for diffrax to have non-terminating event handling, especially for the simple case where the timestamps are known. Unfortunately won't be able to code up the feature now but maybe in the future!