patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

SaveAt and backward solving #382

Closed: NightWinkle closed this issue 9 months ago

NightWinkle commented 9 months ago

Hello,

It is not possible to pass a list of save times via `SaveAt(ts=...)` when solving backwards in time, because the monotonicity check on `ts` seems to be broken in that case.

```python
import diffrax as dfx
import jax.numpy as jnp

def eq(t, x, args):
    return -2 * x

ts = jnp.linspace(0., 1., 10)
reverse_ts = jnp.flip(ts)
t0, t1 = reverse_ts[0], reverse_ts[-1]  # t0 = 1, t1 = 0: solve backward in time
dt0 = reverse_ts[1] - reverse_ts[0]     # negative step size
y0 = 1.

# Case 0: save times in the reverse (decreasing) direction, and...
ts_save_0 = reverse_ts[1:-1]
saveat_0 = dfx.SaveAt(ts=ts_save_0)
sol0 = dfx.diffeqsolve(dfx.ODETerm(eq), dfx.Euler(), t0, t1, dt0, y0, saveat=saveat_0)

# ...case 1: save times in the forward (increasing) direction. Both seem to be broken.
ts_save_1 = ts[1:-1]
saveat_1 = dfx.SaveAt(ts=ts_save_1)
sol1 = dfx.diffeqsolve(dfx.ODETerm(eq), dfx.Euler(), t0, t1, dt0, y0, saveat=saveat_1)
```

The error thrown is: `XlaRuntimeError: saveat.ts must be increasing or decreasing.`

patrick-kidger commented 9 months ago

Hi there! I think this is intentional: when solving backward, all timelike quantities should also be flipped, i.e. given in decreasing order. Off the top of my head, that means `SaveAt(ts=...)` and `PIDController(step_ts=..., jump_ts=...)`.

NightWinkle commented 9 months ago


Sure, but neither way works for me: the array of ts fails in both reverse and forward order. Maybe I am missing something, if this is not reproducible on your side.

patrick-kidger commented 9 months ago

Ah! So `sol0` should work, and does work on my machine. Can you make sure that you've updated to the latest versions of JAX, jaxlib, Equinox and Diffrax, just in case any of those are the culprit?
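Before upgrading, it can help to record exactly which versions are installed, so they can be pasted into the issue. A minimal sketch using only the standard library's `importlib.metadata`; the helper `installed_versions` is hypothetical and not part of any of these packages.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names of the packages mentioned in this thread.
PACKAGES = ("jax", "jaxlib", "equinox", "diffrax")


def installed_versions(packages=PACKAGES):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```

For a CPU-only setup, something like `pip install --upgrade jax jaxlib equinox diffrax` then brings everything up to date; GPU builds of JAX are installed with different extras, so check the JAX installation guide in that case.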

NightWinkle commented 9 months ago

After multiple tries I got it to work, but I must say I'm unsure what was causing the issue. Probably something between the keyboard and the computer.