Possible issue with ReversibleHeun solver instability #417

AddisonHowe commented 6 months ago

I'm running into an issue using the ReversibleHeun solver, which may or may not just be an issue of choosing a proper step size. I've tried to make a MWE that still has the essence of my use case.

I have a quadratic potential function $\phi(x,y;t)$ that defines gradient dynamics, and that shifts in time so that the fixed point of the system moves around. I'm trying to simulate Langevin dynamics, and diffrax has been really useful so far.

It looks as though the ReversibleHeun method becomes unstable, but in a bit of an odd way, and I can't quite figure out what the reason is. Notably, the instability persists without any noise in the system.

The example below defines the potential, defines the drift as its negative gradient, and uses a WeaklyDiagonalControlTerm for the isotropic, homogeneous noise. I show that the Heun method seems to work fine with a step size of $0.1$ in the zero-noise case, while ReversibleHeun becomes unstable. As $dt$ decreases to $0.001$, the ReversibleHeun result appears to match the Heun result.

I'm wondering whether one should expect to need a small step size for the ReversibleHeun method, or whether there is something deeper going on. Any guidance would be appreciated.

I'm using diffrax version 0.5.0.

import numpy as np
import matplotlib.pyplot as plt

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrandom

from diffrax import VirtualBrownianTree, ODETerm, MultiTerm
from diffrax import WeaklyDiagonalControlTerm, ReversibleHeun, Heun
from diffrax import diffeqsolve, SaveAt

SEED = 123
rng = np.random.default_rng(seed=SEED)
key = jrandom.PRNGKey(seed=rng.integers(2**32))

def sigmoid(t, a, b, tcrit):
    """Sigmoid helper function"""
    return 0.5 * (a + b + (b - a) * jnp.tanh(t-tcrit))

def potential(t, y, args):
    """A quadratic potential function where the fixed point changes.
        p(x,y) = (x - u(t))^2 + (y - v(t))^2
    with u(t) and v(t) sigmoidal functions defined by the arguments in `args`
    """
    a1 = args['a1']
    b1 = args['b1']
    t1 = args['t1'] 
    a2 = args['a2']
    b2 = args['b2']
    t2 = args['t2']
    u = sigmoid(t, a1, b1, t1)
    v = sigmoid(t, a2, b2, t2)
    dy = y - jnp.array([u, v])
    return jnp.sum(dy * dy)

### Define drift and diffusion terms

def f(t, y, args):
    """Drift is defined via the gradient of the potential"""
    return -jax.jacfwd(potential, 1)(t, y, args)

def g(t, y, args):
    """Constant diffusion. Noise scale is a parameter `sigma` in `args`."""
    return args['sigma'] * jnp.ones(y.shape, dtype=jnp.float64)

### Demonstrate Heun Solver works but ReversibleHeun becomes unstable

dt0 = 0.1  # Initial solver step size: ReversibleHeun unstable
# dt0 = 0.01  # Initial solver step size: Beginning of an instability
# dt0 = 0.001  # Initial solver step size: Matches Heun method

args = {
    'a1': 0,  # x fixed point starts at 0, moves to 1 at t=5
    'b1': 1,
    't1': 5,

    'a2': 1, # y fixed point starts at 1, moves to 0 at t=5
    'b2': 0,
    't2': 5,

    'sigma': 0.0,  # SET NOISE TO 0
}

max_steps = 4096 * 8  # increase max number of steps to be safe
vbt_tol = 1e-6  # tolerance on VirtualBrownianTree
t0 = 0.
t1 = 10.
y0 = jnp.array([0, 0], dtype=jnp.float64)  # (0, 0) initial condition

key, subkey = jrandom.split(key, 2)

brownian_motion = VirtualBrownianTree(
    t0, t1, tol=vbt_tol,
    shape=(len(y0),),  # one Brownian component per state dimension
    key=subkey,
)

terms = MultiTerm(
    ODETerm(f),  # drift: negative gradient of the potential
    WeaklyDiagonalControlTerm(g, brownian_motion),
)

ts_save = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=ts_save)

sol_heun = diffeqsolve(
    terms, Heun(),
    t0, t1, dt0=dt0,
    y0=y0, args=args, saveat=saveat, max_steps=max_steps,
)

sol_rev_heun = diffeqsolve(
    terms, ReversibleHeun(),
    t0, t1, dt0=dt0,
    y0=y0, args=args, saveat=saveat, max_steps=max_steps,
)

fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.plot(ts_save, sol_heun.ys, label=['x (heun)','y (heun)'])
ax1.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax1.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.set_title("Heun Method")

ax2.plot(ts_save, sol_rev_heun.ys, label=['x (rev heun)','y (rev heun)'])
ax2.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax2.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.set_title("Reversible Heun Method")

fig.suptitle(f"No noise, dt0={dt0}")

And here's my environment...

patrick-kidger commented 6 months ago

I think this is expected! ReversibleHeun is quite an unstable solver. It often requires smaller step sizes than other solvers. This is partly because it retains additional memory between evaluations (other than just the evolving state). I could believe that this memory, combined with the "moving target" nature of your problem, makes it a particularly poor fit.
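
The role of that extra memory can be illustrated on a toy problem. What follows is a minimal pure-Python sketch, not diffrax's implementation. It assumes the reversible Heun update as usually written, with an auxiliary state `yhat` carried alongside `y`, and applies it to the scalar linear ODE $dy/dt = \lambda y$ with $\lambda = -2$, which is the slope of the drift for a potential of the form $(x - u)^2$ with $u$ held fixed. The function names `heun` and `reversible_heun` are hypothetical helpers introduced only for this illustration.

```python
# Toy comparison on dy/dt = lam * y. Not diffrax code; helper names are hypothetical.

def heun(lam, y0, dt, n_steps):
    """Standard Heun (explicit trapezoidal) method for dy/dt = lam * y."""
    y = y0
    for _ in range(n_steps):
        f0 = lam * y
        y_pred = y + dt * f0                      # Euler predictor
        y = y + 0.5 * dt * (f0 + lam * y_pred)    # trapezoidal corrector
    return y


def reversible_heun(lam, y0, dt, n_steps):
    """Reversible Heun for dy/dt = lam * y (assumed update rule).

    Besides y, the solver carries an auxiliary state yhat and the vector
    field evaluated at yhat; this is the "additional memory".
    """
    y, yhat = y0, y0
    f0 = lam * yhat
    for _ in range(n_steps):
        yhat_next = 2.0 * y - yhat + dt * f0      # reflect yhat about y, then step
        f1 = lam * yhat_next
        y = y + 0.5 * dt * (f0 + f1)              # trapezoidal update using yhat values
        yhat, f0 = yhat_next, f1
    return y


lam, y0, t_end = -2.0, 1.0, 10.0
for dt in (0.1, 0.01, 0.001):
    n = round(t_end / dt)
    print(f"dt={dt}: heun={heun(lam, y0, dt, n):.3e}, "
          f"reversible_heun={reversible_heun(lam, y0, dt, n):.3e}")
```

Under these assumptions, writing $z = \lambda\,dt$, one step of the sketch acts linearly on $(y, \hat{y})$ with multipliers $z \pm \sqrt{1 + z^2}$. For real negative $z$ the root $z - \sqrt{1 + z^2}$ has magnitude greater than one, so a sign-alternating component grows at every step, whereas the Heun multiplier $1 + z + z^2/2$ stays below one here. Shrinking $dt$ does not remove that component in this toy setting; it only makes its initial size much smaller, which would be consistent with the behaviour reported above for $dt = 0.1$, $0.01$ and $0.001$.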