pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
https://pnkraemer.github.io/probdiffeq/
MIT License
29 stars 2 forks source link

Restart IVP solver from terminal value #676

Open lahramon opened 8 months ago

lahramon commented 8 months ago

Hi,

I am trying to integrate a vector field with jumps in the derivative and experience a high integration error when the derivative jumps. As a solution attempt, I would like to split the integration interval in two parts.

Is there a way to restart a filter from the last value of a given posterior, i.e., to continue the integration from the end point of a previous subinterval? I have tried extracting the last state of the MarkovSeq from the posterior and giving it to a filter; however, from the second derivative onwards, the covariances do not match anymore (see example below).

If that works, ideally I would like to run a smoother on this "stitched-together" Markov sequence. Is it possible with probdiffeq to just call the backwards-stepping smoother on a given filtering sequence? Then I could also stitch the smoothing solution together afterward hopefully.

Here is my attempt so far to stitch together two filtering solutions based on your "Posterior Uncertainties" example (but with an uncalibrated solver) :

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.15.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# # Posterior uncertainties

# +
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from jax.config import config

from probdiffeq import ivpsolve, adaptive, timestep
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
from probdiffeq.solvers import uncalibrated, calibrated, solution, markov
from probdiffeq.taylor import autodiff
from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint
from probdiffeq.solvers.strategies.components import corrections, priors

# +
# Apply the plotting defaults shipped with the probdiffeq docs.
plt.rcParams.update(notebook.plot_config())

# Select the JAX backend for diffeqzoo unless one has already been selected.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Enable 64-bit mode and pin JAX to the CPU platform.
for _option, _value in (("jax_enable_x64", True), ("jax_platform_name", "cpu")):
    config.update(_option, _value)
# -

# Isotropic state-space factorisation for a two-dimensional ODE.
impl.select("isotropic", ode_shape=(2,))

# Set an example problem.
#
# Solve the problem on a low resolution and short time-span to achieve large uncertainty.

# +
# Lotka-Volterra example problem: right-hand side, initial value,
# time span and extra arguments of the right-hand side.
f, u0, (t0, t1), f_args = ivps.lotka_volterra()


@jax.jit
def vf(*ys, t):
    """Evaluate the ODE right-hand side; ``t`` is accepted but not used."""
    arguments = (*ys, *f_args)
    return f(*arguments)

# -

# ## Filter

# +
# Correction and prior (four derivatives) for the state-space model.
ts0 = corrections.ts0()
ibm = priors.ibm_adaptive(num_derivatives=4)

# Uncalibrated filter. Alternatives that were tried and left disabled:
#   strategy = smoothers.smoother_adaptive(ibm, ts0)
#   solver = calibrated.mle(filters.filter_adaptive(ibm, ts0))
#   adaptive_solver = adaptive.adaptive(solver, atol=1e-2, rtol=1e-2)
strategy = filters.filter_adaptive(ibm, ts0)
solver = uncalibrated.solver(strategy)

# The full interval [t0, tf] is split at tf_half into two equal halves.
tf_half = t0 + 1.0
tf = t0 + 2.0

num_steps_half = 10
num_steps = 2 * num_steps_half


def _grid(start, stop, num_intervals):
    """Return an equispaced grid with ``num_intervals`` steps, endpoints included."""
    return jnp.linspace(start, stop, endpoint=True, num=num_intervals + 1)


ts = _grid(t0, tf, num_steps)
# Both half-grids contain tf_half: the first ends there, the second starts there.
ts_half_1 = _grid(t0, tf_half, num_steps_half)
ts_half_2 = _grid(tf_half, tf, num_steps_half)
# ts_half_1_onemore = ts[0:num_steps_half+2]

# +
def _initial_condition():
    """Return Taylor coefficients at ``(t0, u0)`` and the matching solver initial condition."""
    coeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=4)
    return coeffs, solver.initial_condition(coeffs, output_scale=10.0)


def _calibrated(result):
    """Calibrate marginals and posterior of ``result`` with its own output scale."""
    scale = result.output_scale
    return (
        solution.calibrate(result.marginals, output_scale=scale),
        solution.calibrate(result.posterior, output_scale=scale),
    )


# Reference solve over the whole grid.
tcoeffs, init = _initial_condition()
sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver)
marginals, posterior = _calibrated(sol)
# (disabled) posterior = markov.select_terminal(posterior)

# +
# same for first half
tcoeffs, init_half_1 = _initial_condition()
# (disabled alternative grid: ts_half_1_onemore)
sol_half_1 = ivpsolve.solve_fixed_grid(vf, init_half_1, grid=ts_half_1, solver=solver)
marginals_half_1, posterior_half_1 = _calibrated(sol_half_1)
# (disabled) posterior_half_1_terminal = markov.select_terminal(posterior_half_1)

# -

import jax.tree_util as jaxtu

# +
# initialize second half with posterior from first half
#
# Take the last entry along the leading axis of every leaf of the first-half
# posterior. tree_map does this leaf by leaf and keeps the tree structure,
# so the arrays no longer have to be iterated row by row in Python.
# NOTE(review): assumes every leaf is stacked along the same leading (time)
# axis, which is what the original zip-based extraction relied on as well.
posterior_half_1_terminal_state = jaxtu.tree_map(lambda leaf: leaf[-1], posterior_half_1)
init_half_2 = (posterior_half_1_terminal_state, sol_half_1.output_scale[-1])
sol_half_2 = ivpsolve.solve_fixed_grid(
    vf, init_half_2, grid=ts_half_2, solver=solver
)

marginals_half_2 = solution.calibrate(sol_half_2.marginals, output_scale=sol_half_2.output_scale)
posterior_half_2 = solution.calibrate(sol_half_2.posterior, output_scale=sol_half_2.output_scale)
# posterior_half_2_terminal = markov.select_terminal(posterior_half_2)

# +
# key = jax.random.PRNGKey(seed=1)
# (qoi, samples), _init = markov.sample(key, posterior, shape=(2,), reverse=True)

# +
def _ordinal(i):
    """Return ``i`` followed by its ordinal suffix as used in the panel titles."""
    return f"{i}" + {1: "st", 2: "nd", 3: "rd"}.get(i, "th")


def _plot_marginals(axes_all, t, marginal_rvs):
    """Draw one solution into the grid of axes.

    For every column ``i`` of ``axes_all`` the top panel shows the mean of the
    ``i``-th state component with a +/- 1.96 standard-deviation band, and the
    bottom panel shows the standard deviation on a logarithmic axis.
    """
    for i, axes_cols in enumerate(axes_all.T):
        ms = marginal_rvs.mean[:, i, :]
        ls = marginal_rvs.cholesky[:, i, :]
        # Row-wise squared norm of the Cholesky factor, then square root.
        stds = jnp.sqrt(jnp.einsum("jn,jn->j", ls, ls))

        axes_cols[0].set_title(f"{_ordinal(i)} deriv.")
        axes_cols[0].plot(t, ms, marker="None")
        for m in ms.T:
            axes_cols[0].fill_between(t, m - 1.96 * stds, m + 1.96 * stds, alpha=0.3)

        axes_cols[1].semilogy(t, stds, marker="None")


# The middle axis of the mean has one entry per plotted column (column 0 is
# titled "0th deriv."), so this count includes the state itself.
_, num_derivatives, _ = marginals.mean.shape

fig, axes_all = plt.subplots(
    nrows=2,
    ncols=num_derivatives,
    sharex=True,
    tight_layout=True,
    figsize=(8, 3),
)

# Full solve first, then the two halves on top, in the same order as before.
# (Plotting of posterior samples was disabled in the original script.)
_plot_marginals(axes_all, sol.t, marginals)
_plot_marginals(axes_all, sol_half_1.t, marginals_half_1)
_plot_marginals(axes_all, sol_half_2.t, marginals_half_2)

plt.show()
# -

# Both half-grids contain tf_half, so the last row of the first half is
# dropped before the two halves are joined along the time axis.
stitched_rows = (sol_half_1.u[:-1, :], sol_half_2.u)
sol_stitched = jnp.concatenate(stitched_rows, axis=0)

# Difference between the single solve and the stitched solve (notebook output).
sol.u - sol_stitched

Thanks a lot, and best regards!

pnkraemer commented 8 months ago

Hi!

Do I understand correctly that you want to have two solves with two different vector fields, but the second solve starts at the first one's terminal value?

That should be possible; however, you might have to do the looping yourself until this feature makes it into the source. (I'd like this to be natively supported, but at the moment, this is not the case.)

If you have a look at the source in solve_fixed_grid, there is a clear call to scan(); you want to take this call, do it twice with different vector fields, and then call the "make user-friendly"-style functions afterwards.

What do you think?

lahramon commented 8 months ago

Thanks for the tip! Exactly, basically I would like to chain the smoothing solves for different vector fields. I will take a look at your suggested approach and let you know how it went.

While I can imagine that this works for the filter, do you have any hope that it would also work for the smoothing solution? How do I make sure that the filtering solution from the first vector field is updated based on the measurements of the second vector field?

pnkraemer commented 8 months ago

Good to hear! Please keep me posted on how it goes.

If it works for the filter, it works for the smoother; the solver returns all backwards transitions, and after both forward passes, either follow up with two backward passes or tree_map(np.stack()) the posterior distributions and smooth/sample as usual.

Does that make sense?

lahramon commented 8 months ago

Yes, thanks! I will give it a try!

lahramon commented 8 months ago

Hi,

small update: I have managed to stitch together solutions with different vector fields; however, how to initialize the solver at the beginning of the non-first interval is still unresolved for me. The main problem would be that, since the derivative jumps, exact initialization is impossible. For now, I have just looked at the case where num_derivatives=1 (you will also see that the solutions are really bad for everything above that).

For the derivative update at the "kink", i.e., the place where the vector field changes, I have tried to "decouple" the two subsequent derivative values by adding a lot of noise on the diffusion matrix for the first (and higher-order) derivatives. Basically, this corresponds to assigning a very high output scale just for the derivatives of the process. What do you think about this "hack"? The code is on the fork: https://github.com/lahramon/probdiffeq/tree/stitching_solvers (currently only works for isotropic setting, for filtering and smoothing, meaning also ts1() corrections cannot be used at the moment; do you think that could help?)

Here is how the solutions look for an affine vector field:

Unmodified (blue: just first vector field, orange: restarting solver with exact initialization at kink, green: using PN solver "naively" for discontinuous vector field):

nonsmooth_probdiffeq_error_actual

New approach (blue: just first vector field, orange: restarting solver with exact initialization at kink, green: new approach with inflated prior process noise at transition):

nonsmooth_probdiffeq_new_approach

Here is the code to generate the figures:

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.15.2
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---

# +
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from jax.config import config

from probdiffeq import ivpsolve, adaptive, timestep
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
from probdiffeq.solvers import uncalibrated, calibrated, solution, markov
from probdiffeq.taylor import autodiff
from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint
from probdiffeq.solvers.strategies.components import corrections, priors

# +
# plt.rcParams.update(notebook.plot_config())

# Select the JAX backend for diffeqzoo unless one has already been selected.
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

# Enable 64-bit mode and pin JAX to the CPU platform.
for _option, _value in (("jax_enable_x64", True), ("jax_platform_name", "cpu")):
    config.update(_option, _value)
# -

# Scalar ODE with the isotropic state-space factorisation.
ode_dim = 1
impl.select("isotropic", ode_shape=(ode_dim,))

# +
# Two affine problems; the second supplies the vector field after the switch.
f, u0, (t0, t1), f_args = ivps.affine_independent(a=2.0, b=1.0)
f_2, u0_2, (t0_2, t1_2), f_args_2 = ivps.affine_independent(a=-10.0, b=-1.0)
# Alternative (disabled): logistic problems
# f, u0, (t0, t1), f_args = ivps.logistic(parameters=(1.0,-0.5))
# f_2, u0_2, (t0_2, t1_2), f_args_2 = ivps.logistic(parameters=(-1.0,1.0))

# The initial values are promoted to arrays with at least one dimension.
u0, u0_2 = jnp.atleast_1d(u0), jnp.atleast_1d(u0_2)


def _as_vector_field(rhs, rhs_args):
    """Wrap ``rhs`` into a jitted ``(ys, t)`` vector field that ignores ``t``."""

    @jax.jit
    def vector_field(ys, t):
        return rhs(ys, *rhs_args)

    return vector_field


vf = _as_vector_field(f, f_args)
vf_2 = _as_vector_field(f_2, f_args_2)

# -

num_points = 20
num_points_fine = 30
idx_half = num_points // 2

# Coarse and fine grids over the same interval.
ts = jnp.linspace(t0, t1, num=num_points, endpoint=True)
ts_fine = jnp.linspace(t0, t1, num=num_points_fine, endpoint=True)

# Split the coarse grid in two; ts_2[0] is where the vector field switches.
ts_1, ts_2 = ts[:idx_half], ts[idx_half:]


# combine vector fields with piecewise function
@jax.jit
def vf_combined(*ys, t):
    """Evaluate ``vf`` before the switching time ``ts_2[0]`` and ``vf_2`` from there on."""
    before_switch = t < ts_2[0]
    return jnp.where(before_switch, vf(*ys, t), vf_2(*ys, t))

# +
# Experiment switches and solver construction.
num_derivatives = 1
output_scale_guess = 1e1
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
# ibm = priors.ibm_discretised(ts, num_derivatives=num_derivatives)
plot_samples = False
use_filter = False
use_ts1 = False
use_calibrated = False

# Bind whichever correction is selected to a single name. Previously the TS1
# branch assigned to `is1` while the strategy always read `ts0`, so setting
# use_ts1=True ended in a NameError.
if use_ts1:
    correction = corrections.ts1()
else:
    correction = corrections.ts0()

if use_filter:
    strategy = filters.filter_adaptive(ibm, correction)
else:
    strategy = smoothers.smoother_adaptive(ibm, correction)

if use_calibrated:
    solver = calibrated.mle(strategy)
else:
    solver = uncalibrated.solver(strategy)

# +

def solve(vf, u0, ts):
    """Solve the IVP on a fixed grid and calibrate the resulting marginals.

    Reads the module-level ``solver``, ``num_derivatives``,
    ``output_scale_guess``, ``use_filter`` and ``plot_samples``.

    Parameters
    ----------
    vf
        A vector field callable as ``vf(y, t=...)``, or a list of such
        callables (one per sub-grid).
    u0
        Initial value used for the Taylor-coefficient initialisation.
    ts
        Time grid, or a list of grids if ``vf`` is a list.

    Returns
    -------
    tuple
        ``(sol, marginals, posterior, qoi, samples)``. The last three entries
        are ``None`` unless ``plot_samples`` is set and ``use_filter`` is off.

    Raises
    ------
    ValueError
        If exactly one of ``vf`` and ``ts`` is a list.
    """
    stitched = isinstance(vf, list)
    if stitched != isinstance(ts, list):
        raise ValueError("vf and ts must both be lists or both not")

    # Expand the Taylor series at the first node of the grid that is actually
    # solved, not at the global t0; a solve that is restarted on a later
    # sub-grid would otherwise be initialised at the wrong time.
    first_vf, t_start = (vf[0], ts[0][0]) if stitched else (vf, ts[0])
    tcoeffs = autodiff.taylor_mode_scan(
        lambda y: first_vf(y, t=t_start), (u0,), num=num_derivatives
    )
    init = solver.initial_condition(tcoeffs, output_scale=output_scale_guess)

    if stitched:
        # NOTE(review): solve_fixed_grid_arr is not defined in this script;
        # presumably it comes from the modified probdiffeq fork -- confirm.
        sol = ivpsolve.solve_fixed_grid_arr(
            vf, init, grid=ts, solver=solver, use_filter=use_filter
        )
    else:
        sol = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver)

    marginals = solution.calibrate(sol.marginals, output_scale=sol.output_scale)

    if plot_samples and not use_filter:
        posterior = solution.calibrate(sol.posterior, output_scale=sol.output_scale)
        posterior = markov.select_terminal(posterior)

        key = jax.random.PRNGKey(seed=1)
        (qoi, samples), _init = markov.sample(key, posterior, shape=(2,), reverse=True)
    else:
        posterior, qoi, samples = None, None, None

    return sol, marginals, posterior, qoi, samples

# -

def plot(sol, marginals, posterior, qoi, samples):
    """Plot one or several solutions into a shared 2 x (num_derivatives + 1) figure.

    Each argument is either a single object or a list of objects; lists are
    iterated in lockstep and drawn on top of each other. The top row shows the
    marginal means with +/- 1.96 standard-deviation bands (and samples, if any),
    the bottom row the standard deviations on a logarithmic axis. ``posterior``
    and ``qoi`` are accepted for symmetry with ``solve`` but are not drawn.

    Reads the module-level ``num_derivatives``, ``ts`` and ``ts_2``.

    Returns
    -------
    The created matplotlib figure.
    """
    # squeeze=False keeps axes_all two-dimensional even for a single column,
    # so the column-wise iteration below works for any num_derivatives.
    fig, axes_all = plt.subplots(
        nrows=2,
        ncols=num_derivatives+1,
        sharex=True,
        tight_layout=True,
        figsize=(8, 8),
        squeeze=False,
    )

    if not isinstance(sol, list):
        sol = [sol]
        marginals = [marginals]
        posterior = [posterior]
        qoi = [qoi]
        samples = [samples]

    ordinal_suffix = {1: "st", 2: "nd", 3: "rd"}

    # Distinct loop names: the original loop rebound the function's own
    # parameters while iterating over them.
    for sol_k, marginals_k, _posterior_k, _qoi_k, samples_k in zip(
        sol, marginals, posterior, qoi, samples
    ):
        for i, axes_cols in enumerate(axes_all.T):
            ms = marginals_k.mean[:, i, :]
            ls = marginals_k.cholesky[:, i, :]
            # Row-wise squared norm of the Cholesky factor, then square root.
            stds = jnp.sqrt(jnp.einsum("jn,jn->j", ls, ls))

            axes_cols[0].set_title(f"{i}{ordinal_suffix.get(i, 'th')} deriv.")
            axes_cols[0].plot(sol_k.t, ms, marker="None")

            if samples_k is not None:
                samps = samples_k[..., i, :]
                for s in samps:
                    # One thin line per ODE component instead of hard-coded
                    # components 0 and 1, which fails for a scalar ODE.
                    for d in range(s.shape[-1]):
                        axes_cols[0].plot(
                            sol_k.t[:-1], s[..., d], color=f"C{d}", linewidth=0.35, marker="None"
                        )

            for m in ms.T:
                axes_cols[0].fill_between(sol_k.t, m - 1.96 * stds, m + 1.96 * stds, alpha=0.3)

            # Mark the switching time, then faintly mark every grid point.
            axes_cols[0].vlines(ts_2[0], ymin=axes_cols[0].get_ylim()[0], ymax=axes_cols[0].get_ylim()[1], color="black")
            axes_cols[0].vlines(ts, ymin=axes_cols[0].get_ylim()[0], ymax=axes_cols[0].get_ylim()[1], color="black", alpha=0.1)
            axes_cols[1].semilogy(sol_k.t, stds, marker="None")

    return fig

# First vector field on the whole grid.
result_1 = solve(vf, u0, ts)
sol_1, marginals_1, posterior_1, qoi_1, samples_1 = result_1

# Restart with the second vector field from the first solution's value at the switch.
u0_2 = sol_1.u[idx_half,:]
result_2 = solve(vf_2, u0_2, ts_2)
sol_2, marginals_2, posterior_2, qoi_2, samples_2 = result_2

# zip regroups the per-solve tuples into one list per plot argument.
fig = plot(*map(list, zip(result_1, result_2)))
fig.savefig("test_probdiffeq_nonsmooth_just_first_vector_field.png")

# Piecewise vector field on the coarse grid and on the fine grid.
result_err = solve(vf_combined, u0, ts)
sol_err, marginals_err, posterior_err, qoi_err, samples_err = result_err
result_f = solve(vf_combined, u0, ts_fine)
sol_f, marginals_f, posterior_f, qoi_f, samples_f = result_f

fig = plot(*map(list, zip(result_1, result_2, result_err)))
# Disabled alternative: compare against the fine-grid solve instead
# plot([sol,sol_2,sol_f], [marginals,marginals_2,marginals_f], [posterior,posterior_2,posterior_f], [qoi,qoi_2,qoi_f], [samples,samples_2,samples_f])
fig.savefig("figures/nonsmooth_probdiffeq_error_actual.png", bbox_inches="tight", dpi=300)

# +
import importlib
import probdiffeq
import probdiffeq.ivpsolve

# Reload the package and then its ivpsolve module, in that order.
for _module in (probdiffeq, probdiffeq.ivpsolve):
    importlib.reload(_module)
from probdiffeq.ivpsolve import solve_fixed_grid_arr

# One vector field per sub-grid, solved as a single stitched problem.
vf_arr = [vf, vf_2]
ts_arr = [ts_1, ts_2]
sol, marginals, posterior, qoi, samples = solve(vf_arr, u0, ts_arr)

# Restart with the second vector field from the stitched solution's value at idx_half.
u0_2 = sol.u[idx_half,:]
sol_2, marginals_2, posterior_2, qoi_2, samples_2 = solve(vf_2, u0_2, ts_2)
# -

# Overlay the piecewise-vector-field solve, the restarted solve and the stitched solve.
sols_to_plot = [sol_err, sol_2, sol]
marginals_to_plot = [marginals_err, marginals_2, marginals]
posteriors_to_plot = [posterior_err, posterior_2, posterior]
qois_to_plot = [qoi_err, qoi_2, qoi]
samples_to_plot = [samples_err, samples_2, samples]
fig = plot(sols_to_plot, marginals_to_plot, posteriors_to_plot, qois_to_plot, samples_to_plot)
fig.savefig("figures/nonsmooth_probdiffeq_new_approach.png", bbox_inches="tight", dpi=300)

Do you have any other ideas on how to make the transition between the two vector fields cleaner?

For my personal project, I think I will first consider smooth vector fields for now, but maybe this first approach helps in finding a solution for non-smooth vector fields some time soon!

pnkraemer commented 8 months ago

Hi,

Thanks for the update!

I am not entirely sure I understand correctly. Do you want to solve an ODE with a discontinuous vector field? If so, plugging the (discontinuous) vector field into the solver should work. (In your code, that would be vf_combined).

For instance, the code below is an adaptation from the quickstart (with a discontinuous vector field and smoothing) and solves the IVP with arbitrary numbers of derivatives:

"""Solve the logistic equation."""

import jax.numpy as jnp
from jax.config import config

from probdiffeq import ivpsolve
from probdiffeq.impl import impl
from probdiffeq.solvers import uncalibrated
from probdiffeq.solvers.strategies import smoothers
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.taylor import autodiff
from probdiffeq.solvers import markov
import matplotlib.pyplot as plt

# Essentially the same as in the quickstart

config.update("jax_platform_name", "cpu")

# Scalar initial value and time span.
u0 = jnp.asarray([0.1])
t0, t1 = 0.0, 1.0

impl.select("dense", ode_shape=(1,))

ibm = priors.ibm_adaptive(num_derivatives=1)
# This is a TS1 correction, so it is named accordingly (it was bound to the
# misleading name `ts0` before).
ts1 = corrections.ts1(ode_order=1)

strategy = smoothers.smoother_adaptive(ibm, ts1)
solver = uncalibrated.solver(strategy)

# A discontinuous vector field 

def jump_vector_field(t_jump):
    """Build a piecewise-constant vector field that switches sign at ``t_jump``.

    The returned callable takes the state ``y`` and a keyword-only time ``t``.
    It returns an array shaped like ``y`` filled with ``10.0`` while
    ``t < t_jump`` and with ``-10.0`` otherwise.
    """

    def vf(y, *, t):
        """Evaluate the vector field."""
        level = 10.0 * jnp.ones_like(y)
        return jnp.where(t < t_jump, level, -level)

    return vf

# The vector field switches sign at the midpoint of the time span.
t_jump = (t0 + t1) / 2
vf = jump_vector_field(t_jump=t_jump)

# Again, same as in quickstart

output_scale = 1.0  # or any other value with the same shape
tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=1)
init = solver.initial_condition(tcoeffs, output_scale)

# Equispaced grid with step dt0, endpoints included.
dt0 = 0.1
num_grid_points = int(1 / dt0) + 1
ts = jnp.linspace(t0, t1, num=num_grid_points, endpoint=True)
solution = ivpsolve.solve_fixed_grid(vf, init, grid=ts, solver=solver)

# Smooth and plot (optional, of course)

# Smoothing marginals from the terminal value backwards.
posterior = markov.select_terminal(solution.posterior)
marginals = markov.marginals(posterior, reverse=True)
mean_s0, mean_s1, *_ = marginals.mean.T

fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True)

# The means are plotted against all grid points except the last one.
panels = zip(ax, ("State", "Derivative"), (mean_s0, mean_s1))
for axis, title, mean in panels:
    axis.set_title(title)
    axis.plot(ts[:-1], mean)

# Mark the jump time and clip every panel to the time span.
for a in ax:
    a.axvline(t_jump, color="black")
    a.set_xlim((t0, t1))
plt.show()

But I might have misunderstood, and you are trying to achieve something else. If this is the case, could you maybe clarify? Thanks :)

pnkraemer commented 8 months ago

An addition to the above:

My example script uses a piecewise constant vector field, yours a piecewise linear one. If the TS0 underdelivers for a piecewise non-constant vector field, switch to TS1 (or SLR1) or increase the grid-resolution. Either approach should improve the situation.