nathanaelbosch / parallel-in-time-ode-filters

Parallel-in-Time ODE Filters in JAX

Parallel-in-time Probabilistic Numerical ODE Solvers

This repo contains the implementation and experiment code for the paper "Parallel-in-Time Probabilistic Numerical ODE Solvers", available on arXiv (arXiv:2310.01145).

Project environment setup

The project uses Poetry. After installing Poetry, you should be able to set up the project with

poetry install

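Usage

The following script solves the FitzHugh-Nagumo problem with both the parallel-in-time solver and the sequential solver, then plots the posterior means with two-standard-deviation bands.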

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from pof.ivp import fitzhughnagumo
from pof.solver import solve, sequential_eks_solve

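ivp = fitzhughnagumo()

# Parallel-in-time solver on a grid of 100 time points in [0, 100]
ts_par = jnp.linspace(0, 100, 100)
ys_par, info_par = solve(f=ivp.f, y0=ivp.y0, ts=ts_par, order=3, init="constant")

# Sequential solver on a finer grid of 300 time points, for comparison
ts_seq = jnp.linspace(0, 100, 300)
ys_seq, info_seq = sequential_eks_solve(f=ivp.f, y0=ivp.y0, ts=ts_seq, order=3)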

def plot_result(ts, ys, ax=None):
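    # ys holds the posterior means and the Cholesky factors of the covariances
    means, chol_covs = ys
    # Recover the full covariance matrix Cov = L @ L.T at each time point
    covs = jax.vmap(lambda c: c @ c.T, in_axes=0)(chol_covs)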

    if ax is None:
        fig, ax = plt.subplots(1, 1)

    # Plot the means and shade +/- 2 standard deviations for each dimension
    ax.plot(ts, means, marker="o")
    for i in range(means.shape[1]):
        ax.fill_between(
            ts,
            means[:, i] - 2 * jnp.sqrt(covs[:, i, i]),
            means[:, i] + 2 * jnp.sqrt(covs[:, i, i]),
            alpha=0.2,
            color=f"C{i}",
        )
    return ax

fig, axes = plt.subplots(2, 1)
plot_result(ts_seq, ys_seq, ax=axes[0])
plot_result(ts_par, ys_par, ax=axes[1])
axes[0].set_ylim(-3, 3)
axes[0].set_title("Sequential")
axes[1].set_ylim(-3, 3)
axes[1].set_title("Parallel")
fig.tight_layout()
plt.show()
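plot_result forms each full covariance as c @ c.T only to read off its diagonal. Because diag(L L^T)_i = sum_j L_ij^2, the marginal standard deviations can also be taken directly from the square-root factors. The sketch below assumes, as plot_result does, that the second output of the solvers satisfies Cov = L @ L.T; the helper name marginal_stds is hypothetical and not part of pof.

```python
import jax
import jax.numpy as jnp


def marginal_stds(chol_covs):
    # Hypothetical helper. chol_covs: (N, d, d) array of factors L with Cov = L @ L.T.
    # diag(L @ L.T)[i] = sum_j L[i, j] ** 2, so the marginal variances are the
    # row-wise sums of squares; no d x d matrix product per time point is needed.
    return jnp.sqrt(jnp.sum(chol_covs**2, axis=-1))


# Compare against the explicit product used in plot_result, on random factors.
key = jax.random.PRNGKey(0)
chol = jnp.tril(jax.random.normal(key, (5, 2, 2)))
covs = jax.vmap(lambda c: c @ c.T)(chol)
stds_explicit = jnp.sqrt(jnp.diagonal(covs, axis1=1, axis2=2))
print(jnp.allclose(marginal_stds(chol), stds_explicit))  # True
```

The identity holds for any factor L with Cov = L @ L.T, triangular or not, so it does not depend on how the solver computes its square roots.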

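As we understand the paper, the parallel solver gets its speed-up from expressing the filtering and smoothing recursions as associative operations that can be evaluated with a parallel prefix scan, which needs O(log N) sequential steps instead of O(N) given enough parallel hardware. As a toy illustration of that general principle only (this is not code from pof; combine, parallel_recurrence and sequential_recurrence are hypothetical names), the scalar recurrence x_k = a_k * x_{k-1} + b_k can be computed with jax.lax.associative_scan by composing affine maps:

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    # Compose two affine maps x -> a * x + b: apply `earlier` first, then `later`.
    # a2 * (a1 * x + b1) + b2 = (a2 * a1) * x + (a2 * b1 + b2), which is associative.
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2


def parallel_recurrence(a, b, x0):
    # Prefix-compose all maps with an associative scan (logarithmic depth),
    # then apply each composed map to the initial value x0.
    A, B = jax.lax.associative_scan(combine, (a, b))
    return A * x0 + B


def sequential_recurrence(a, b, x0):
    # Reference: the same recurrence, evaluated one step at a time.
    def step(x, ab):
        a_k, b_k = ab
        x_new = a_k * x + b_k
        return x_new, x_new

    _, xs = jax.lax.scan(step, x0, (a, b))
    return xs


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (1000,), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (1000,))
x0 = jnp.asarray(1.0)
xs_par = parallel_recurrence(a, b, x0)
xs_seq = sequential_recurrence(a, b, x0)
print(jnp.allclose(xs_par, xs_seq, rtol=1e-4, atol=1e-4))
```

The same idea carries over to Kalman filtering and smoothing, where the scan elements are Gaussian conditionals instead of scalar affine maps; the paper describes the actual construction.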

Testing

Run the tests with tox:

tox -e py3

Running plain tox, without -e, additionally runs black and isort.

Reference

@misc{bosch2023parallelintime,
      title={Parallel-in-Time Probabilistic Numerical ODE Solvers}, 
      author={Nathanael Bosch and Adrien Corenflos and Fatemeh Yaghoobi and Filip Tronarp and Philipp Hennig and Simo Särkkä},
      year={2023},
      eprint={2310.01145},
      archivePrefix={arXiv},
      primaryClass={math.NA}
}