patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

How to use reversible integrators? #241

Open · ameya98 opened this issue 1 year ago

ameya98 commented 1 year ago

I wanted a time-reversible integrator, and according to the documentation, SemiImplicitEuler() should work. However, I ran the following round trip (forward from t=0 to t=1, then backward from t=1 to t=0):

```python
import diffrax
import jax
import jax.numpy as jnp

def f(t, p, args):
    return -p

def g(t, q, args):
    return -q

term = (diffrax.ODETerm(f), diffrax.ODETerm(g))
solver = diffrax.SemiImplicitEuler()

q0 = jnp.array([12., 3.])
p0 = jnp.array([1., 2.])

# Forward pass
y0 = (q0, p0)
solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.01, y0=y0,
                               saveat=diffrax.SaveAt(t1=True, solver_state=True))
y1 = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)

# Backward pass
solution = diffrax.diffeqsolve(term, solver, t0=1, t1=0, dt0=-0.01, y0=y1,
                               solver_state=solution.solver_state)
y0_computed = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)
```

I see that y0_computed is:

```
(Array([11.797033 ,  2.9734273]),
 Array([0.85248  , 1.9949753]))
```

which is close to, but different from, y0. Am I using the wrong arguments?

patrick-kidger commented 1 year ago

Right! So getting exact reconstruction on the backward pass is possible, but there isn't really a neat API for it at the moment. Generally speaking, you need to be familiar with the numerical method involved.

In this case, SemiImplicitEuler works by unpacking y_a, y_b = y, then making an Euler step in y_a (using the old y_b), and then making an Euler step in y_b (using the new y_a). To reverse this, we need to do it the other way around: take an Euler step in y_b and then take an Euler step in y_a. (More generally: write down the mathematics for a single step of the numerical solver, rearrange it to recompute things backwards, and then figure out how to match that up against a solver run backwards.)
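Written out by hand for the example in this issue (the first term drives q via dq/dt = f(p), the second drives p via dp/dt = g(q), step size h), one forward step and its algebraic inverse look roughly like this. This is a sketch of the update order described above, not a transcription of diffrax's source:

$$
\begin{aligned}
&\text{forward:} && q_{n+1} = q_n + h\,f(p_n), && p_{n+1} = p_n + h\,g(q_{n+1}),\\
&\text{inverse:} && p_n = p_{n+1} - h\,g(q_{n+1}), && q_n = q_{n+1} - h\,f(p_n).
\end{aligned}
$$

The inverse recovers p first, using only quantities known at the end of the step, and then recovers q from the already-recovered p_n. That is a semi-implicit Euler step with the roles of (q, f) and (p, g) swapped and h replaced by -h. Running the original ordering backwards instead evaluates f at p_{n+1} rather than p_n, so each step is only approximately undone.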

Here, this means that we can reconstruct things by flipping the two entries in y (and swapping the two terms to match) on the backward pass:

```python
import diffrax
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(t, p, args):
    return -p * 1.02

def g(t, q, args):
    return -q * 1.03

term = (diffrax.ODETerm(f), diffrax.ODETerm(g))
solver = diffrax.SemiImplicitEuler()

q0 = jnp.array([12., 3.])
p0 = jnp.array([1., 2.])

# Forward pass
y0 = (q0, p0)
solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.01, y0=y0)
y1 = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)
(q1, p1) = y1

# Backward pass: swap the terms and the state entries
term_flip = (diffrax.ODETerm(g), diffrax.ODETerm(f))
y1_flip = (p1, q1)
solution = diffrax.diffeqsolve(term_flip, solver, t0=1, t1=0, dt0=-0.01, y0=y1_flip)
y0_recomputed = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)
(p0_recomputed, q0_recomputed) = y0_recomputed
print(p0, q0)
print(p0_recomputed, q0_recomputed)
# [1. 2.] [12.  3.]
# [1. 2.] [12.  3.]
```
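To see why the question's round trip drifts while the flipped one doesn't, here is a plain-NumPy sketch that hand-rolls the semi-implicit Euler update (first component stepped using the second, then the second using the new first) instead of calling diffrax. The helper names `semi_implicit_euler_step` and `integrate` are hypothetical, and a fixed 100 steps of size 0.01 stand in for diffrax's stepping from t=0 to t=1:

```python
import numpy as np

# Vector fields from the question: dq/dt = f(p), dp/dt = g(q).
def f(p):
    return -p

def g(q):
    return -q

def semi_implicit_euler_step(vf_a, vf_b, y_a, y_b, dt):
    # Step y_a using the old y_b, then step y_b using the *new* y_a.
    y_a_new = y_a + dt * vf_a(y_b)
    y_b_new = y_b + dt * vf_b(y_a_new)
    return y_a_new, y_b_new

def integrate(vf_a, vf_b, y_a, y_b, dt, num_steps):
    # Apply a fixed number of semi-implicit Euler steps.
    for _ in range(num_steps):
        y_a, y_b = semi_implicit_euler_step(vf_a, vf_b, y_a, y_b, dt)
    return y_a, y_b

q0 = np.array([12.0, 3.0])
p0 = np.array([1.0, 2.0])
dt, num_steps = 0.01, 100

# Forward pass from t=0 to t=1.
q1, p1 = integrate(f, g, q0, p0, dt, num_steps)

# Naive backward pass: same term order, negative step (as in the question).
q_naive, p_naive = integrate(f, g, q1, p1, -dt, num_steps)

# Flipped backward pass: swap both the terms and the state entries.
p_flip, q_flip = integrate(g, f, p1, q1, -dt, num_steps)

naive_error = max(np.max(np.abs(q_naive - q0)), np.max(np.abs(p_naive - p0)))
flip_error = max(np.max(np.abs(q_flip - q0)), np.max(np.abs(p_flip - p0)))
print(f"naive backward error:   {naive_error:.3e}")
print(f"flipped backward error: {flip_error:.3e}")
```

With the original term order, each backward step evaluates f at the end-of-step p rather than the start-of-step p, so the error accumulates over the run; the flipped version undoes each step up to floating-point rounding.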

(An equally valid, and arguably better, solution would be to define an alternate solver that reverses SemiImplicitEuler.)
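As a rough illustration of that alternative, here is the per-step update such a reversed solver might use, sketched in plain NumPy rather than as a diffrax solver class. The function names are hypothetical; the idea is to keep the same terms and the same state ordering, but apply the two sub-steps in the opposite order with a negative step:

```python
import numpy as np

def semi_implicit_euler_step(vf_a, vf_b, y_a, y_b, dt):
    # Forward ordering: y_a first (using old y_b), then y_b (using new y_a).
    y_a = y_a + dt * vf_a(y_b)
    y_b = y_b + dt * vf_b(y_a)
    return y_a, y_b

def reversed_semi_implicit_euler_step(vf_a, vf_b, y_a, y_b, dt):
    # Hypothetical "reversed" step: same terms, same state ordering,
    # but y_b is updated first (using y_a), then y_a (using the new y_b).
    y_b = y_b + dt * vf_b(y_a)
    y_a = y_a + dt * vf_a(y_b)
    return y_a, y_b

# Vector fields from the answer above: dq/dt = f(p), dp/dt = g(q).
def f(p):
    return -p * 1.02

def g(q):
    return -q * 1.03

q_start = np.array([12.0, 3.0])
p_start = np.array([1.0, 2.0])
dt, num_steps = 0.01, 100

q, p = q_start, p_start
for _ in range(num_steps):
    q, p = semi_implicit_euler_step(f, g, q, p, dt)
for _ in range(num_steps):
    q, p = reversed_semi_implicit_euler_step(f, g, q, p, -dt)

print(q, p)
```

This only shows the update rule; turning it into something usable with diffeqsolve would mean writing a custom solver, which isn't attempted here.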

Incidentally, SemiImplicitEuler doesn't actually use its solver state, so as you can see I haven't bothered to pass it between the forward and backward passes. (In general, though, you might need to modify the state between the forward and backward passes.)

Making it possible to use reversible methods with less knowledge of the numerics is something I'd like to support, so I'm going to mark this as a feature. Probably we'd first need to implement a few more reversible solvers to figure out what the appropriate abstractions are here.