patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Possible issue with ReversibleHeun solver instability #417

Open AddisonHowe opened 6 months ago

AddisonHowe commented 6 months ago

I'm running into an issue with the ReversibleHeun solver, which may just come down to choosing an appropriate step size. I've put together an MWE that keeps the essence of my use case.

I have a quadratic potential function $\phi(x,y;t)$ that defines gradient dynamics and shifts in time, so that the fixed point of the system moves around. I'm trying to simulate Langevin dynamics, and diffrax has been really useful so far.

It looks as though the ReversibleHeun method becomes unstable, but in a slightly odd way, and I can't quite figure out why. Notably, the instability persists even with no noise in the system.

The example below defines the potential, defines the drift as its negative gradient, and uses a WeaklyDiagonalControlTerm for the isotropic, homogeneous noise. I show that the Heun method works fine with a step size of $0.1$ in the zero-noise case, while ReversibleHeun becomes unstable. As the step size decreases to $0.001$, the ReversibleHeun solution appears to match the Heun one.
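For reference, the model described above can be written out explicitly. This is a sketch read off from the `potential`, `sigmoid`, `f` and `g` functions in the code; the symbols $W_t$ (Brownian motion) and $\sigma$ (for `args['sigma']`) are notation introduced here, not taken from the post.

$$
\phi(x, y; t) = (x - u(t))^2 + (y - v(t))^2,
\qquad
u(t) = \tfrac{1}{2}\left[a_1 + b_1 + (b_1 - a_1)\tanh(t - t_1)\right],
\qquad
v(t) = \tfrac{1}{2}\left[a_2 + b_2 + (b_2 - a_2)\tanh(t - t_2)\right]
$$

$$
d\begin{pmatrix} x \\ y \end{pmatrix}
= -\nabla\phi \, dt + \sigma \, dW_t
= -2\begin{pmatrix} x - u(t) \\ y - v(t) \end{pmatrix} dt + \sigma \, dW_t
$$

With $\sigma = 0$, each component is a linear ODE that relaxes toward its moving target at rate $2$.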

I'm wondering whether one should expect to need a small step size for the ReversibleHeun method, or whether something deeper is going on. Any guidance would be appreciated.

I'm using diffrax version 0.5.0.

import numpy as np
import matplotlib.pyplot as plt

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrandom

from diffrax import VirtualBrownianTree, ODETerm, MultiTerm
from diffrax import WeaklyDiagonalControlTerm, ReversibleHeun, Heun
from diffrax import diffeqsolve, SaveAt

SEED = 123
rng = np.random.default_rng(seed=SEED)
key = jrandom.PRNGKey(seed=rng.integers(2**32))

def sigmoid(t, a, b, tcrit):
    """Sigmoid helper function"""
    return 0.5 * (a + b + (b - a) * jnp.tanh(t-tcrit))

def potential(t, y, args):
    """A quadratic potential function where the fixed point changes.
        p(x,y) = (x - u(t))^2 + (y - v(t))^2
    with u(t) and v(t) sigmoidal functions defined by the arguments in `args`
    """
    a1 = args['a1']
    b1 = args['b1']
    t1 = args['t1'] 
    a2 = args['a2']
    b2 = args['b2']
    t2 = args['t2']
    u = sigmoid(t, a1, b1, t1)
    v = sigmoid(t, a2, b2, t2)
    dy = y - jnp.array([u, v])
    return jnp.sum(dy * dy)

### Define drift and diffusion terms

def f(t, y, args):
    """Drift is defined via the gradient of the potential"""
    return -jax.jacfwd(potential, 1)(t, y, args)

def g(t, y, args):
    """Constant diffusion. Noise scale is a parameter `sigma` in `args`."""
    return args['sigma'] * jnp.ones(y.shape, dtype=jnp.float64)

### Demonstrate that the Heun solver works but ReversibleHeun becomes unstable

dt0 = 0.1  # Initial solver step size: ReversibleHeun unstable
# dt0 = 0.01  # Initial solver step size: Beginning of an instability
# dt0 = 0.001  # Initial solver step size: Matches Heun method

args = {
    'a1': 0,  # x fixed point starts at 0, moves to 1 at t=5
    'b1': 1,
    't1': 5,

    'a2': 1, # y fixed point starts at 1, moves to 0 at t=5
    'b2': 0,
    't2': 5,

    'sigma': 0.0  # SET NOISE TO 0
}

max_steps = 4096 * 8  # increase max number of steps to be safe
vbt_tol = 1e-6  # tolerance on VirtualBrownianTree
t0 = 0.
t1 = 10.
y0 = jnp.array([0, 0], dtype=jnp.float64)  # (0, 0) initial condition

key, subkey = jrandom.split(key, 2)

brownian_motion = VirtualBrownianTree(
    t0, t1, tol=vbt_tol, 
    shape=(len(y0),), 
    key=subkey
)

terms = MultiTerm(
    ODETerm(f), 
    WeaklyDiagonalControlTerm(g, brownian_motion)
)

ts_save = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=ts_save)

sol_heun = diffeqsolve(
    terms, Heun(), 
    t0, t1, dt0=dt0, 
    y0=y0, 
    saveat=saveat,
    args=args,
    max_steps=max_steps,
)

sol_rev_heun = diffeqsolve(
    terms, ReversibleHeun(), 
    t0, t1, dt0=dt0, 
    y0=y0, 
    saveat=saveat,
    args=args,
    max_steps=max_steps,
)

fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.plot(ts_save, sol_heun.ys, label=['x (heun)','y (heun)'])
ax1.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax1.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax1.set_xlabel('t')
ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax1.set_title("Heun Method")

ax2.plot(ts_save, sol_rev_heun.ys, label=['x (rev heun)','y (rev heun)'])
ax2.plot(
    ts_save, sigmoid(ts_save, args['a1'], args['b1'], args['t1']),
    ':', label='fixed point x'
)
ax2.plot(
    ts_save, sigmoid(ts_save, args['a2'], args['b2'], args['t2']),
    ':', label='fixed point y'
)
ax2.set_xlabel('t')
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax2.set_title("Reversible Heun Method")

fig.suptitle(f"No noise, dt0={dt0}")
plt.tight_layout()
plt.show()

And here's my environment...

absl-py                   2.0.0              pyhd8ed1ab_0    conda-forge
appnope                   0.1.3              pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
brotli                    1.1.0                hb547adb_1    conda-forge
brotli-bin                1.1.0                hb547adb_1    conda-forge
brotli-python             1.1.0            py39hb198ff7_1    conda-forge
bzip2                     1.0.8                h93a5062_5    conda-forge
c-ares                    1.23.0               h93a5062_0    conda-forge
ca-certificates           2024.2.2             hf0a4a13_0    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
chex                      0.1.85             pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
comm                      0.1.4              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.0            py39he9de807_0    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
debugpy                   1.8.0            py39hb198ff7_1    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
diffrax                   0.5.0                    pypi_0    pypi
equinox                   0.11.2             pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.0              pyhd8ed1ab_0    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
fonttools                 4.46.0           py39h17cfd9d_0    conda-forge
freetype                  2.12.1               hadb7bae_2    conda-forge
gmp                       6.3.0                h965bd2d_0    conda-forge
gmpy2                     2.1.2            py39h0b4f9c6_1    conda-forge
idna                      3.6                pyhd8ed1ab_0    conda-forge
importlib-metadata        7.0.0              pyha770c72_0    conda-forge
importlib-resources       6.1.1              pyhd8ed1ab_0    conda-forge
importlib_metadata        7.0.0                hd8ed1ab_0    conda-forge
importlib_resources       6.1.1              pyhd8ed1ab_0    conda-forge
iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
ipykernel                 6.26.0             pyh3cd1d5f_0    conda-forge
ipympl                    0.9.4              pyhd8ed1ab_0    conda-forge
ipython                   8.18.1             pyh31011fe_2    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
ipywidgets                8.1.2              pyhd8ed1ab_0    conda-forge
jax                       0.4.23             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.19          cpu_py39he2ef314_0    conda-forge
jaxtyping                 0.2.25             pyhd8ed1ab_0    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_1    conda-forge
joblib                    1.3.2              pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.0              pyhd8ed1ab_0    conda-forge
jupyter_core              5.5.0            py39h2804cbe_0    conda-forge
jupyterlab_widgets        3.0.10             pyhd8ed1ab_0    conda-forge
kiwisolver                1.4.5            py39hbd775c9_1    conda-forge
lcms2                     2.16                 ha0e7c42_0    conda-forge
lerc                      4.0.0                h9a09cb3_0    conda-forge
libabseil                 20230802.1      cxx17_h13dd4ca_0    conda-forge
libblas                   3.9.0           20_osxarm64_openblas    conda-forge
libbrotlicommon           1.1.0                hb547adb_1    conda-forge
libbrotlidec              1.1.0                hb547adb_1    conda-forge
libbrotlienc              1.1.0                hb547adb_1    conda-forge
libcblas                  3.9.0           20_osxarm64_openblas    conda-forge
libcxx                    16.0.6               h4653b0c_0    conda-forge
libdeflate                1.19                 hb547adb_0    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgfortran               5.0.0           13_2_0_hd922786_1    conda-forge
libgfortran5              13.2.0               hf226fd6_1    conda-forge
libgrpc                   1.58.2               h19be7b0_0    conda-forge
libjpeg-turbo             3.0.0                hb547adb_1    conda-forge
liblapack                 3.9.0           20_osxarm64_openblas    conda-forge
libopenblas               0.3.25          openmp_h6c19121_0    conda-forge
libpng                    1.6.39               h76d750c_0    conda-forge
libprotobuf               4.24.3               hf590ac1_1    conda-forge
libre2-11                 2023.06.02           h1753957_0    conda-forge
libsodium                 1.0.18               h27ca646_1    conda-forge
libsqlite                 3.44.2               h091b4b1_0    conda-forge
libtiff                   4.6.0                ha8a6c65_2    conda-forge
libuv                     1.46.0               hb547adb_0    conda-forge
libwebp-base              1.3.2                hb547adb_0    conda-forge
libxcb                    1.15                 hf346824_0    conda-forge
libzlib                   1.2.13               h53f4e23_5    conda-forge
lineax                    0.0.4                    pypi_0    pypi
llvm-openmp               17.0.6               hcd81f8e_0    conda-forge
markupsafe                2.1.3            py39h0f82c59_1    conda-forge
matplotlib                3.8.2            py39hdf13c20_0    conda-forge
matplotlib-base           3.8.2            py39h1a09f3e_0    conda-forge
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.3.1            py39hf8cecc8_2    conda-forge
mpc                       1.3.1                h91ba8db_0    conda-forge
mpfr                      4.2.1                h9546428_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.4                  h463b476_2    conda-forge
nest-asyncio              1.5.8              pyhd8ed1ab_0    conda-forge
networkx                  3.2.1              pyhd8ed1ab_0    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numpy                     1.26.2           py39heee92a0_0    conda-forge
openjpeg                  2.5.0                h4c1507b_3    conda-forge
openssl                   3.3.0                h0d3ecfb_0    conda-forge
opt-einsum                3.3.0                hd8ed1ab_2    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
optax                     0.1.7              pyhd8ed1ab_0    conda-forge
optimistix                0.0.6                    pypi_0    pypi
packaging                 23.2               pyhd8ed1ab_0    conda-forge
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.1.0           py39h755f0b7_0    conda-forge
pip                       23.3.1             pyhd8ed1ab_0    conda-forge
platformdirs              4.0.0              pyhd8ed1ab_0    conda-forge
plnn                      0.1.0                     dev_0    <develop>
pluggy                    1.3.0              pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.41             pyha770c72_0    conda-forge
psutil                    5.9.5            py39h0f82c59_1    conda-forge
pthread-stubs             0.4               h27ca646_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pygments                  2.17.2             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.1              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytest                    7.4.3              pyhd8ed1ab_0    conda-forge
python                    3.9.18          hfa1ae8a_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.9                      4_cp39    conda-forge
pytorch                   2.0.0           cpu_generic_py39hdf1d804_3    conda-forge
pyzmq                     25.1.1           py39he0a3c8b_2    conda-forge
re2                       2023.06.02           h6135d0a_0    conda-forge
readline                  8.2                  h92ec313_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
scikit-learn              1.3.2            py39h172c841_2    conda-forge
scipy                     1.11.4           py39h36c428d_0    conda-forge
setuptools                68.2.2             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h156473d_2    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
threadpoolctl             3.2.0              pyha21a80b_0    conda-forge
tk                        8.6.13               h5083fa2_1    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
toolz                     0.12.0             pyhd8ed1ab_0    conda-forge
torchvision               0.15.2          cpu_py39ha53f654_4    conda-forge
tornado                   6.3.3            py39h0f82c59_1    conda-forge
tqdm                      4.66.2             pyhd8ed1ab_0    conda-forge
traitlets                 5.14.0             pyhd8ed1ab_0    conda-forge
typeguard                 2.13.3             pyhd8ed1ab_0    conda-forge
typing-extensions         4.9.0                hd8ed1ab_0    conda-forge
typing_extensions         4.9.0              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
unicodedata2              15.1.0           py39h0f82c59_0    conda-forge
urllib3                   2.1.0              pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.12             pyhd8ed1ab_0    conda-forge
wheel                     0.42.0             pyhd8ed1ab_0    conda-forge
widgetsnbextension        4.0.10             pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hb547adb_0    conda-forge
xorg-libxdmcp             1.1.3                h27ca646_0    conda-forge
xz                        5.2.6                h57fd34a_0    conda-forge
zeromq                    4.3.5                h965bd2d_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zstd                      1.5.5                h4f39d0f_0    conda-forge

patrick-kidger commented 6 months ago

I think this is expected! ReversibleHeun is quite an unstable solver, and it often requires smaller step sizes than other solvers. This is partly because it retains additional memory between evaluations (beyond just the evolving state). I could believe that this memory, combined with the "moving target" nature of your problem, makes it a particularly poor fit.
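To make the role of the extra state concrete, here is a minimal sketch on the linear test equation $y' = \lambda y$ with $\lambda = -2$, which is the drift of the zero-noise example with a fixed target. It assumes the reversible Heun update takes the form written in `reversible_heun_step` (an auxiliary state `yhat` carried alongside `y`). That form is an assumption for illustration and has not been checked against the diffrax source. The helper names and the 5-time-unit window are hypothetical.

```python
import math


def reversible_heun_step(y, yhat, lam, h):
    """One reversible-Heun-style step for y' = lam * y (assumed update rule)."""
    f0 = lam * yhat                      # vector field at the auxiliary state
    yhat_next = 2.0 * y - yhat + h * f0  # reflect yhat through y, then step
    f1 = lam * yhat_next
    y_next = y + 0.5 * h * (f0 + f1)     # trapezoidal update of the solution
    return y_next, yhat_next


def heun_factor(z):
    """Per-step amplification of standard Heun on y' = lam * y, with z = lam * h."""
    return 1.0 + z + 0.5 * z * z


def reversible_heun_factors(z):
    """Eigenvalues of the one-step matrix [[1 + z, z**2 / 2], [2, z - 1]].

    The matrix acts on (y, yhat) and follows from the update above.
    """
    root = math.sqrt(z * z + 1.0)
    return z + root, z - root  # (physical mode, parasitic mode)


lam = -2.0    # the potential (x - u)^2 gives drift -2 * (x - u)
t_span = 5.0  # hypothetical window over which a perturbation can grow

for h in (0.1, 0.01, 0.001):
    z = lam * h
    physical, parasitic = reversible_heun_factors(z)
    n_steps = round(t_span / h)
    growth = abs(parasitic) ** n_steps
    print(f"h={h}: Heun {heun_factor(z):.4f}, "
          f"reversible Heun {physical:.4f} and {parasitic:.4f}, "
          f"parasitic growth over {t_span} time units ~ {growth:.3g}")
```

Under this assumed update, the two factors multiply to $-1$, so for any real negative $z = \lambda h$ one of them has magnitude greater than $1$ and alternates in sign. Its growth over a fixed time window is roughly $e^{|\lambda| T}$ regardless of $h$ (of order $10^4$ here), so a smaller step does not remove the growing mode. A plausible reading is that a smaller step only reduces how strongly the moving target excites that mode, which would fit the observation that `dt0 = 0.01` shows the beginning of an instability while `dt0 = 0.001` matches Heun.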