ameya98 opened this issue 1 year ago
Right! So getting exact reconstruction on the backward pass is possible, but doesn't really have a neat API at the moment. Generally speaking you need to be familiar with the numerical method involved.
In this case, `SemiImplicitEuler` works by unpacking `y_a, y_b = y`, then making an Euler step in `y_a`, and then making an Euler step in `y_b`. When we reverse this, we need to do it the other way around: take an Euler step in `y_b` and then take an Euler step in `y_a`. (More generally: write down the mathematics for a single step of the numerical solver, rearrange it to recompute things backwards, and then figure out how to match that up against a solver run backwards.)
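Applying that recipe to `SemiImplicitEuler` gives the following. Write $h$ for the step size, and assume (as in the example here) that the first term's vector field $f$ is evaluated at $y_b$, that the second term's vector field $g$ is evaluated at the updated $y_a$, and that both are evaluated at the start-of-step time. One forward step, and the same equations rearranged to recover the old state, are:

$$
\begin{aligned}
\text{forward:}\quad y_a^{n+1} &= y_a^{n} + h\, f(t_n, y_b^{n}), &
y_b^{n+1} &= y_b^{n} + h\, g(t_n, y_a^{n+1}), \\
\text{inverse:}\quad y_b^{n} &= y_b^{n+1} - h\, g(t_n, y_a^{n+1}), &
y_a^{n} &= y_a^{n+1} - h\, f(t_n, y_b^{n}).
\end{aligned}
$$

Running the same solver on the swapped state $(y_b, y_a)$ with swapped terms $(g, f)$ and step $-h$ from $t_{n+1}$ produces exactly the two inverse lines. The only difference is that the vector fields are evaluated at $t_{n+1}$ rather than $t_n$, which does not matter for autonomous fields such as `f` and `g` in the example, since they ignore `t`.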
Here, this means that we can reconstruct things by swapping both the two entries in `y` and the order of the two terms on the backward pass:
```python
import diffrax
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(t, p, args):
    return -p * 1.02


def g(t, q, args):
    return -q * 1.03


term = (diffrax.ODETerm(f), diffrax.ODETerm(g))
solver = diffrax.SemiImplicitEuler()
q0 = jnp.array([12., 3.])
p0 = jnp.array([1., 2.])

# Forward pass
y0 = (q0, p0)
solution = diffrax.diffeqsolve(term, solver, t0=0, t1=1, dt0=0.01, y0=y0)
# Only the value at t1 is saved, so take index 0 along the time axis.
y1 = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)
(q1, p1) = y1

# Backward pass: swap the order of both the terms and the state.
term_flip = (diffrax.ODETerm(g), diffrax.ODETerm(f))
y1_flip = (p1, q1)
solution = diffrax.diffeqsolve(term_flip, solver, t0=1, t1=0, dt0=-0.01, y0=y1_flip)
y0_recomputed = jax.tree_util.tree_map(lambda arr: arr[0], solution.ys)
(p0_recomputed, q0_recomputed) = y0_recomputed

print(p0, q0)
print(p0_recomputed, q0_recomputed)
# [1. 2.] [12. 3.]
# [1. 2.] [12. 3.]
```
(An equally valid, and arguably better, solution would be to define an alternate solver that reverses `SemiImplicitEuler`.)
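As a rough illustration of that alternative, here is a plain-Python sketch (no diffrax, scalar states) of a semi-implicit Euler step and a hypothetical reversed step that performs the swapping internally. The function names `semi_implicit_euler_step`, `reversed_semi_implicit_euler_step` and `integrate` are made up for this sketch and are not diffrax API; the step form (update `y_a` from `y_b`, then `y_b` from the new `y_a`, both at the start-of-step time) is an assumption mirroring the description above. For comparison it also runs the unmodified step with a negative step size, which only approximates the inverse and so lands close to, but not on, the initial state, much like the mismatch described in the question.

```python
from typing import Callable, Tuple

# Vector fields follow the thread's signature f(t, y, args); `args` is unused here.
VectorField = Callable[[float, float, object], float]
State = Tuple[float, float]


def semi_implicit_euler_step(f: VectorField, g: VectorField, t: float, h: float,
                             y: State, args=None) -> State:
    # Step y_a using f evaluated at y_b, then y_b using g evaluated at the *new* y_a.
    # Assumption: both vector fields are evaluated at the start-of-step time t.
    y_a, y_b = y
    y_a_new = y_a + h * f(t, y_b, args)
    y_b_new = y_b + h * g(t, y_a_new, args)
    return (y_a_new, y_b_new)


def reversed_semi_implicit_euler_step(f: VectorField, g: VectorField, t: float, h: float,
                                      y: State, args=None) -> State:
    # Hypothetical "reversed" solver: same term order (f, g) and state order
    # (y_a, y_b) as the forward solver, but it swaps both internally, which is
    # what term_flip and y1_flip do by hand in the diffrax example.
    y_a, y_b = y
    y_b_new, y_a_new = semi_implicit_euler_step(g, f, t, h, (y_b, y_a), args)
    return (y_a_new, y_b_new)


def integrate(step, f: VectorField, g: VectorField, t0: float, h: float,
              n_steps: int, y: State) -> State:
    # Fixed-step loop; a negative h runs backwards in time.
    for i in range(n_steps):
        y = step(f, g, t0 + i * h, h, y)
    return y


def f(t, p, args):  # dq/dt, depends only on p
    return -p * 1.02


def g(t, q, args):  # dp/dt, depends only on q
    return -q * 1.03


y0 = (12.0, 1.0)  # (q, p) as scalars, for simplicity
y1 = integrate(semi_implicit_euler_step, f, g, 0.0, 0.01, 100, y0)

# Exact reversal (up to floating-point rounding): swap terms and state order.
y0_recomputed = integrate(reversed_semi_implicit_euler_step, f, g, 1.0, -0.01, 100, y1)
# Naive reversal: the same forward step with a negative step size. Each step is
# off by O(h^2), so over the whole interval it drifts away from y0 by O(h).
y0_naive = integrate(semi_implicit_euler_step, f, g, 1.0, -0.01, 100, y1)

print(y0, y0_recomputed, y0_naive)
```

Hiding the swap inside the step function means callers keep the original term and state order, which is roughly what a dedicated reversing solver would provide.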
Incidentally, `SemiImplicitEuler` doesn't actually use `state`, so as you can see I haven't bothered to pass it between the forward and backward passes. (In general, though, you might need to modify `state` between the forward and backward passes.)
Making it possible to use reversible methods with less knowledge of the numerics is something I'd like to support, so I'm going to mark this as a feature. Probably we'd first need to implement a few more reversible solvers to figure out what the appropriate abstractions are here.
I wanted a time-reversible integrator, and according to the documentation `SemiImplicitEuler()` should work. However, I ran the following:
I see that `y0_computed` is:
which is close to, but different from, `y0`. Am I using the wrong arguments?