patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.47k stars, 134 forks

[Feature request] A delay differential equations solver #406

Open miguelgondu opened 7 months ago

miguelgondu commented 7 months ago

Dear Patrick,

Thanks for this library, it's pretty neat!

I'm currently supervising a bachelor's student's thesis on Neural ODEs, and we've been meaning to use diffrax to study delay differential equations. I'm trying to implement a delay differential equation solver inside diffrax, and I would like to know if I'm on the right track.

For context, let me give a brief overview of how delay differential equations work and how such a solver could be implemented. In its simplest form, a Delay Differential Equation (DDE) with a constant delay has a vector field $f$ that depends not only on the current state $y(t)$ but also on $y(t-\tau)$, where $\tau\in\mathbb{R}_{>0}$. In other words,

$$y'(t) = f(t, y(t), y(t-\tau)).$$

Initial value problems involving DDEs provide a history instead of a single initial value (for example, $y(t) = \phi(t)$ for $t \leq t_0$), and are solved in chunks using the "method of steps". In short, one solves an ordinary IVP on each interval $[t_0 + k\tau, t_0 + (k+1)\tau]$ for $k = 0, 1, 2, \dots$, where $y(t-\tau)$ is already known from the previous interval (or from the history $\phi$ when $k = 0$). (More details in Chap. 9 of this reference.)
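As a worked illustration of the method of steps (the equation is chosen here for illustration and is not taken from the referenced chapter), take $y'(t) = -y(t-1)$, i.e. $\tau = 1$, with $t_0 = 0$ and history $\phi(t) = 1$ for $t \leq 0$. On each interval the delayed argument is already known, so each chunk is an ordinary IVP:

$$
\begin{aligned}
t \in [0, 1]:&\quad y'(t) = -\phi(t-1) = -1,\quad y(0) = 1 \;\Rightarrow\; y(t) = 1 - t,\\
t \in [1, 2]:&\quad y'(t) = -\bigl(1 - (t-1)\bigr) = t - 2,\quad y(1) = 0 \;\Rightarrow\; y(t) = \tfrac{t^2}{2} - 2t + \tfrac{3}{2}.
\end{aligned}
$$

Note that $y'$ jumps at $t = 0$ (from $0$ to $-1$) and $y''$ jumps at $t = 1$ (from $0$ to $1$): discontinuities in the history propagate forward in steps of $\tau$, becoming smoother each time.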

In practice, a DDE can be solved numerically by choosing the step size so that $t-\tau$ always falls on a grid point (e.g. $\Delta t = \tau/m$ for some integer $m$). To compute $y_{n+1}$ we then evaluate the vector field at $y_n$ and at the $y_{n-k}$ corresponding to $y(t_n-\tau)$ (with $\Delta t = \tau/m$, $k = m$). Alternatively, if $t-\tau$ does not lie on the grid, one could build a Hermite interpolant between the grid points on either side of it, but I plan to focus on the first approach.
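A minimal sketch of the on-grid approach, assuming a constant delay and forward Euler: with `dt = tau / m`, the delayed value $y(t_n - \tau)$ is exactly $y_{n-m}$, so the integrator only needs a rolling buffer of the last $m + 1$ states. This is plain JAX, not diffrax's API, and `euler_constant_delay` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def euler_constant_delay(f, phi, t0, tau, m, n_steps):
    """Forward Euler for y'(t) = f(t, y(t), y(t - tau)), with y(t) = phi(t) for t <= t0.

    The step size dt = tau / m makes t_n - tau = t_{n-m} a grid point,
    so the delayed state is read straight from a buffer (no interpolation).
    """
    dt = tau / m
    # History sampled on the grid t_{-m}, ..., t_0 covering [t0 - tau, t0].
    hist_ts = t0 + dt * jnp.arange(-m, 1)
    buffer = jax.vmap(phi)(hist_ts)  # shape (m + 1, d): y_{n-m}, ..., y_n

    def step(buffer, n):
        t_n = t0 + n * dt
        y_n, y_delayed = buffer[-1], buffer[0]  # y_n and y_{n-m} = y(t_n - tau)
        y_next = y_n + dt * f(t_n, y_n, y_delayed)
        # Drop the oldest state and append the newest one.
        buffer = jnp.concatenate([buffer[1:], y_next[None]], axis=0)
        return buffer, y_next

    y0 = buffer[-1]
    _, ys = jax.lax.scan(step, buffer, jnp.arange(n_steps))
    ts = t0 + dt * jnp.arange(n_steps + 1)
    return ts, jnp.concatenate([y0[None], ys], axis=0)


# y'(t) = -y(t - 1) with y(t) = 1 for t <= 0, integrated up to t = 2.
ts, ys = euler_constant_delay(
    f=lambda t, y, y_delayed: -y_delayed,
    phi=lambda t: jnp.ones(1),
    t0=0.0, tau=1.0, m=10, n_steps=20,
)
print(ts[10], ys[10, 0])  # t = 1, y close to 0.0
print(ts[20], ys[20, 0])  # t = 2, y close to -0.55
```

The exact solution at $t = 2$ is $-0.5$; increasing `m` shrinks the Euler error. The fixed `dt = tau / m` is exactly what stops working when the delay is not commensurate with the step or depends on the state, which is where interpolation comes in.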

How could I adapt diffrax to let me pass terms of the form $f(t, y(t), y(t-\tau))\,dt$? I imagine I would have to implement a new DelayTerm that inherits from AbstractTerm with a different vf method. Since solvers need to evaluate these vector fields, I imagine I would also need to modify a solver (or create a new one) so that vf is called with the right signature, right?
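One way to leave the solvers untouched, sketched here as an assumption rather than as diffrax's design, is to keep the solver-facing signature `vf(t, y, args)` and let the vector field look up the delayed state itself from a stored dense history. (The MWE later in this thread suggests the PR #169 branch instead passes delayed states to the vector field through a `history` keyword argument.) `DelayedVectorField` is a hypothetical plain-Python wrapper, not a diffrax `AbstractTerm`, and it uses linear rather than Hermite interpolation to stay short.

```python
import jax
import jax.numpy as jnp


class DelayedVectorField:
    """Hypothetical wrapper (not diffrax API): callable as vf(t, y, args),
    but evaluates f(t, y, y(t - tau), args) using a stored dense history."""

    def __init__(self, f, tau, hist_ts, hist_ys):
        self.f = f              # f(t, y, y_delayed, args)
        self.tau = tau          # constant delay tau > 0
        self.hist_ts = hist_ts  # increasing times where y is known, shape (N,)
        self.hist_ys = hist_ys  # states at those times, shape (N, d)

    def history(self, s):
        # Piecewise-linear interpolation of each state component at time s.
        interp_columns = jax.vmap(jnp.interp, in_axes=(None, None, 1))
        return interp_columns(s, self.hist_ts, self.hist_ys)

    def __call__(self, t, y, args):
        return self.f(t, y, self.history(t - self.tau), args)


# Toy history y(t) = t on [0, 2] and the DDE y'(t) = -y(t - 1).
hist_ts = jnp.linspace(0.0, 2.0, 3)
vf = DelayedVectorField(
    f=lambda t, y, y_delayed, args: -y_delayed,
    tau=1.0,
    hist_ts=hist_ts,
    hist_ys=hist_ts[:, None],
)
print(vf(2.5, jnp.array([2.5]), None))  # -y(1.5) = [-1.5]
```

Because the solver only sees an ordinary `vf(t, y, args)`, any ODE step could call it; the catch is that the history has to be extended as the solve progresses, which a dedicated DDE integration loop would need to manage.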

I'm of course happy to contribute the implementation to diffrax once it's up and running.

patrick-kidger commented 7 months ago

So support for DDEs is something we've been noodling over in Diffrax for a while, see #169.

The main reason that PR stalled is that solving general DDEs requires solving several nonlinear optimisation problems, and at the time Optimistix did not exist yet.

Now that it does, we have been meaning to revisit that PR and fix it up to use the root-finding functionality now available in Optimistix.
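Which nonlinear problems arise isn't spelled out here. One common example, assumed for illustration rather than taken from the PR, is locating breaking points when the delay varies, say $\tau = \tau(t)$: derivative discontinuities propagate to the times $t$ at which $t - \tau(t)$ equals an earlier discontinuity $t_{\text{prev}}$, so a solver that wants to step exactly onto them must solve $t - \tau(t) - t_{\text{prev}} = 0$. The sketch below uses a fixed number of plain Newton iterations in JAX instead of Optimistix; `find_breaking_point` and the example delay are hypothetical.

```python
import jax
import jax.numpy as jnp


def find_breaking_point(delay, t_prev, t_guess, n_iter=20):
    """Solve t - delay(t) = t_prev for t with a fixed number of Newton steps.

    delay: scalar function t -> tau(t) > 0 (time-dependent delay).
    t_prev: an earlier discontinuity of the solution (e.g. the initial time).
    """
    def residual(t):
        return t - delay(t) - t_prev

    d_residual = jax.grad(residual)

    def newton_step(t, _):
        return t - residual(t) / d_residual(t), None

    t_init = jnp.asarray(t_guess, dtype=jnp.float32)
    t_star, _ = jax.lax.scan(newton_step, t_init, None, length=n_iter)
    return t_star


delay = lambda t: 1.0 + 0.1 * jnp.sin(t)
t_star = find_breaking_point(delay, t_prev=0.0, t_guess=1.0)
print(t_star, t_star - delay(t_star))  # second value should be close to 0
```

A production implementation would use a root finder with a proper convergence check and error handling (which is what Optimistix provides) rather than a fixed iteration count.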

I should acknowledge that (a) that PR is fairly technical code, but conversely (b) the hard parts are already written.

If you'd be interested in reviving that PR, then this is still a feature I'd be happy to see in Diffrax.

thibmonsel commented 7 months ago

Hello there,

As mentioned by @patrick-kidger, most of the code itself is there (I would say 90%) and functional, but some JAX-related bugs remain (e.g. tracer leakage) that make backpropagation an issue. Happy to discuss if you are interested in giving a hand.

miguelgondu commented 7 months ago

Hi both,

Thanks for the implementation, @thibmonsel! My student and I have been using your dde.ipynb example on a different system of DDEs, and it worked almost out-of-the-box. If I understand correctly, the neuraldde.ipynb example is not finished yet, right?

I would be happy to help, but I fear this is above my skill level (I'm only just starting to use JAX). If you give me pointers on how to get started, I could give it a try, but I can't promise much.

thibmonsel commented 7 months ago

That's great to hear! Integrating DDEs should be robust (so dde.ipynb should work fine). However, fitting a DDE with a neural net raises "Exception: Leaked trace" when run under jax.check_tracer_leaks().
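For readers new to JAX: a leaked tracer is an abstract value created while JAX traces a function that escapes the trace, typically through a Python side effect such as appending to a global list. `jax.check_tracer_leaks()` asks JAX to check for such escaped tracers after tracing, and using an escaped tracer later raises `jax.errors.UnexpectedTracerError`. The toy example below is unrelated to diffrax; it follows the side-effect pattern shown in the JAX documentation for that error.

```python
import jax
import jax.numpy as jnp

leaked = []  # Python-side state that outlives the traced function
caught = None


@jax.jit
def double(x):
    leaked.append(x)  # side effect: the tracer escapes the jit trace
    return 2.0 * x


print(double(jnp.float32(1.0)))  # the computation itself still returns 2.0

try:
    leaked[0] + 1.0  # using the escaped tracer outside its trace
except jax.errors.UnexpectedTracerError as err:
    caught = err
    print("UnexpectedTracerError raised")
```

In a training loop the leak is usually less obvious (e.g. a tracer stored on a module or in a closure during `grad`/`vmap`), but the mechanism is the same.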

A minimal working example (MWE) reproducing this with a neural DDE:

```python
# Assumes the DDE branch from PR #169, which provides diffrax.Delays and the
# delays= argument of diffrax.diffeqsolve.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

import diffrax


class Func(eqx.Module):
    """Vector field: a linear layer applied to the current and delayed states."""

    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        # Input is [y(t), y(t - tau)], hence 2 * data_size features (one delay).
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        # history: the delayed state(s), concatenated alongside y.
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Euler(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,  # constant history: y(t) = y0 for t <= t0
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            adjoint=diffrax.DirectAdjoint(),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    # vmap over the batch of initial conditions; ts is shared.
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


if __name__ == "__main__":
    seed = np.random.randint(0, 1000)
    key = jrandom.PRNGKey(seed)
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = jnp.ones_like(ts)[..., None]
    length_size, datasize = ys.shape

    # A single constant delay tau = 1.0.
    delays = diffrax.Delays(delays=[lambda t, y, args: 1.0])
    model_dde = NeuralDDE(datasize, delays, key=key)

    with jax.check_tracer_leaks():
        loss, grads = grad_loss(model_dde, ts, ys)

    # Batched version: two trajectories, constant at 2 and at 3.
    ys = jnp.concatenate([2 * jnp.ones((1, 10, 1)), 3 * jnp.ones((1, 10, 1))], axis=0)

    # Possibly side-effecting silently: no error is raised here?
    loss, grads = grad_loss_batch(model_dde, ts, ys)

    # Under the leak check: a genuinely leaked tracer in the batched case, or a
    # false positive as described in the Notes at:
    # https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
    with jax.check_tracer_leaks():
        loss, grads = grad_loss_batch(model_dde, ts, ys)
```

Happy to discuss on this thread or more in depth via email (or other medium).