miguelgondu opened this issue 7 months ago.
So support for DDEs is something we've been noodling over in Diffrax for a while, see #169.
The main reason that PR stalled is that solving general DDEs requires solving several nonlinear optimisation problems, and at the time Optimistix did not exist yet.
Now that it does, we have been meaning to revisit that PR and fix it up to use the root-finding functionality now available in Optimistix.
I must acknowledge that (a) this is fairly technical code, but conversely (b) the hard parts are already written.
If you'd be interested in reviving that PR then this is still a feature I'd be happy to see in Diffrax.
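To make the root-finding point concrete, here is a sketch under simplifying assumptions (not the PR's code). With a time-dependent delay $\tau(t)$, derivative discontinuities propagate from an earlier breakpoint $\xi$ (e.g. the initial time) to the times $t$ solving $t - \tau(t) = \xi$. For a constant delay these are simply $\xi + \tau$, but in general they have to be found numerically, and when $\tau$ also depends on $y(t)$ each such solve involves the dense solution. This is one place nonlinear solves show up. The hypothetical helper `find_breakpoint` below uses plain bisection rather than an Optimistix solver, and assumes the caller supplies a bracket on which $t - \tau(t) - \xi$ changes sign.

```python
def find_breakpoint(delay, xi, t_lo, t_hi, tol=1e-12, max_iter=200):
    """Return t in [t_lo, t_hi] with t - delay(t) == xi, found by bisection.

    `delay` is a hypothetical time-dependent delay tau(t); `xi` is an earlier
    breakpoint (e.g. the initial time). The bracket must contain a sign change.
    """

    def g(t):
        # The root of g is the time whose delayed argument t - tau(t) hits xi.
        return t - delay(t) - xi

    g_lo, g_hi = g(t_lo), g(t_hi)
    if g_lo * g_hi > 0:
        raise ValueError("t - delay(t) - xi does not change sign on [t_lo, t_hi]")
    for _ in range(max_iter):
        t_mid = 0.5 * (t_lo + t_hi)
        g_mid = g(t_mid)
        if g_lo * g_mid <= 0:
            t_hi = t_mid  # root lies in the left half
        else:
            t_lo, g_lo = t_mid, g_mid  # root lies in the right half
        if t_hi - t_lo < tol:
            break
    return 0.5 * (t_lo + t_hi)


def tau(t):
    # Example delay chosen for illustration: tau(t) = 1 + t / 2.
    return 1.0 + 0.5 * t


print(find_breakpoint(tau, 0.0, 0.0, 10.0))  # close to 2.0
print(find_breakpoint(tau, 2.0, 2.0, 10.0))  # close to 6.0
```

With $\tau(t) = 1 + t/2$ and $\xi = 0$, the next breakpoint solves $t/2 - 1 = 0$, i.e. $t = 2$; starting from $\xi = 2$ it is $t = 6$.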
Hello there,
As mentioned by @patrick-kidger, most of the code itself is there (I would say 90%) and functional, but some JAX-related bugs remain (e.g. tracer leakage) that make backpropagation an issue. Happy to discuss if you are interested in giving a hand.
Hi both,
Thanks for the implementation, @thibmonsel! My student and I have been using your `dde.ipynb` example on a different system of DDEs, and it worked almost out of the box. If I understand correctly, the `neuraldde.ipynb` example is not finished yet, right?
I would be happy to help, but I fear this is above my skill level (I'm only now starting to use JAX). If you give me pointers on how to get started, I could give it a try, but I can't promise much.
That's great to hear!
Integrating DDEs by itself should be robust (so `dde.ipynb` should work fine).
However, fitting a DDE with a neural net throws `Exception: Leaked trace` when combined with `jax.check_tracer_leaks()`.
An MWE for this could be:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

# `diffrax.Delays` and the `delays=` argument come from the in-progress DDE PR,
# not from released Diffrax.
import diffrax


class Func(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        # Input is [y(t), y(t - tau)] for a single delay, hence 2 * data_size.
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        # `history` holds the delayed states, one entry per delay.
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Euler(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,  # history function: y(t) = y0 for t <= t0
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            adjoint=diffrax.DirectAdjoint(),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    # vmap over the batch of trajectories; the time grid is shared.
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


if __name__ == "__main__":
    seed = np.random.randint(0, 1000)
    key = jrandom.PRNGKey(seed)
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = jnp.ones_like(ts)[..., None]
    length_size, datasize = ys.shape
    # A single constant delay tau = 1.0.
    delays = diffrax.Delays(delays=[lambda t, y, args: 1.0])
    model_dde = NeuralDDE(datasize, delays, key=key)

    with jax.check_tracer_leaks():
        loss, grads = grad_loss(model_dde, ts, ys)

    # Batched version
    ys = jnp.concatenate([2 * jnp.ones((1, 10, 1)), 3 * jnp.ones((1, 10, 1))], axis=0)
    # Silently side-effecting, no error?
    loss, grads = grad_loss_batch(model_dde, ts, ys)
    # Batch leaked tracer, or a false positive as described in the Notes at:
    # https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
    with jax.check_tracer_leaks():
        loss, grads = grad_loss_batch(model_dde, ts, ys)
```
Happy to discuss on this thread or more in depth via email (or other medium).
Dear Patrick,
Thanks for this library, it's pretty neat!
I'm currently supervising a bachelor's student's thesis on NeuralODEs, and we've been meaning to use `diffrax` to study delay differential equations. I'm currently trying to implement a delay differential equation solver inside `diffrax`, and I would like to know if I'm on the right track.

For context, let me give a brief overview of how delay differential equations work and how such a solver could be implemented. In its simplest form, a constant-delay differential equation (DDE) has a vector field $f$ that depends not only on the current state $y(t)$ but also on $y(t-\tau)$, where $\tau\in\mathbb{R}_{>0}$. In other words,
$$y'(t) = f(t, y(t), y(t-\tau)).$$
Initial value problems involving DDEs provide a history instead of a single initial value ($y(t) = \phi(t)$ for $t \leq 0$, for example), and are solved in chunks using the "method of steps". In short, one solves an IVP on successive intervals of the form $[t_0 + k\tau, t_0 + (k+1)\tau]$. (More details in Chap. 9 of this reference.)
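As a small illustration of the method of steps (an example chosen here, not taken from the reference), take $y'(t) = -y(t-1)$, so $\tau = 1$, with $t_0 = 0$ and history $\phi(t) = 1$ for $t \le 0$. On each interval the delayed term is already known from the previous one, so every step reduces to an ordinary integral:

$$
\begin{aligned}
t \in [0, 1]: &\quad y'(t) = -\phi(t-1) = -1 &&\Rightarrow\quad y(t) = 1 - t,\\
t \in [1, 2]: &\quad y'(t) = -y(t-1) = -(2 - t) &&\Rightarrow\quad y(t) = y(1) - \int_1^t (2 - s)\,ds = \frac{t^2}{2} - 2t + \frac{3}{2}.
\end{aligned}
$$

In particular $y(2) = -\tfrac{1}{2}$. Note that $y'$ jumps at $t = 0$ (from $0$ in the history to $-1$), while at $t = 1$ only $y''$ jumps (from $0$ to $1$): the initial discontinuity is smoothed by one order per interval of length $\tau$.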
In practice, a DDE can be solved numerically by choosing the step size so that $t-\tau$ always lands on a grid point. To predict $y_{t+1}$, we need to evaluate the vector field at $y_{t}$ and at some $y_{t-k}$ corresponding to $y(t-\tau)$. Alternatively, if $t-\tau$ falls between grid points, one could e.g. build a Hermite interpolant between the two surrounding grid points, but I plan to focus on the first approach.

How could I adapt `diffrax` to let me pass terms of the form `f(t, y(t), y(t-tau))dt`? I imagine I have to implement a new `DelayTerm` that inherits from `AbstractTerm` with a different `vf` method. Since solvers need to evaluate these vector fields, I imagine I would also need to modify an existing solver (or create a new one) in which `vf` is called with the right signature, right?

I'm of course happy to contribute the implementation to diffrax once it's up and running.
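A minimal sketch of the grid-based approach described in the question, in plain NumPy rather than Diffrax. `DelayTerm` here is a hypothetical stand-in (it does not subclass Diffrax's `AbstractTerm`) whose `vf` takes the delayed state as an extra argument, and `euler_dde` is a hypothetical fixed-step explicit Euler loop that picks $h = \tau/m$ so that $y(t_n - \tau)$ is exactly the stored value $y_{n-m}$. It assumes a single constant delay and that $t_1 - t_0$ is a multiple of $h$.

```python
import numpy as np


class DelayTerm:
    """Hypothetical delay-aware term (not Diffrax's API).

    Wraps a vector field with signature f(t, y, y_delayed) and a constant delay tau.
    """

    def __init__(self, vector_field, tau):
        self.vector_field = vector_field
        self.tau = tau

    def vf(self, t, y, y_delayed):
        return self.vector_field(t, y, y_delayed)


def euler_dde(term, history, t0, t1, m):
    """Explicit Euler with step h = tau / m, so t_n - tau is always a grid point.

    `history(t)` gives y(t) for t <= t0. Returns (ts, ys) on [t0, t1].
    """
    h = term.tau / m
    n_steps = int(round((t1 - t0) / h))
    ts = t0 + h * np.arange(n_steps + 1)
    # Prepend the m history values at t0 - m*h, ..., t0 - h, so that the
    # delayed state y(t_n - tau) sits exactly m slots behind y(t_n).
    ys = [np.asarray(history(t0 - (m - j) * h), dtype=float) for j in range(m)]
    ys.append(np.asarray(history(t0), dtype=float))
    for n in range(n_steps):
        y_n = ys[m + n]  # y(t0 + n*h)
        y_delayed = ys[n]  # y(t0 + n*h - tau)
        ys.append(y_n + h * term.vf(ts[n], y_n, y_delayed))
    # Drop the prepended history; keep only the values on [t0, t1].
    return ts, np.stack(ys[m:])


# Example: y'(t) = -y(t - 1) with history y(t) = 1 for t <= 0.
term = DelayTerm(lambda t, y, y_delayed: -y_delayed, tau=1.0)
ts, ys = euler_dde(term, lambda t: np.array([1.0]), 0.0, 2.0, m=1000)
print(ts[-1], ys[-1])
```

For this test equation the exact solution (via the method of steps) has $y(1) = 0$ and $y(2) = -1/2$; with $m = 1000$ the Euler result at $t = 2$ is within about $10^{-3}$ of that, and since explicit Euler is first order the error shrinks roughly linearly in $h$. Choosing $h$ as a divisor of $\tau$ avoids interpolation entirely, at the cost of not supporting multiple incommensurate delays or adaptive step sizes.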