pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
https://pnkraemer.github.io/probdiffeq/
MIT License

Shifted updating timestamp for vector field #677

Closed: lahramon closed this issue 10 months ago

lahramon commented 10 months ago

Hi,

Related to issue #676, I have been trying to use the library with non-smooth vector fields. As a test, I used the vector field

$$ \dot{x}(t) = \begin{cases} 1, & \text{if } t < t_{N/2}, \\ -1, & \text{if } t \geq t_{N/2}, \end{cases} $$
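For reference, the exact solution of this ODE is continuous and piecewise linear, with its kink (the sign flip of $\dot{x}$) exactly at $t_{N/2}$. The minimal NumPy sketch below assumes an initial value `x0` at `t0 = ts[0]` and a ten-point grid on $[0, 1]$; the interval and `x0 = 1.0` are illustrative assumptions, not values read from `diffeqzoo`.

```python
import numpy as np

def exact_solution(ts, x0, t_switch):
    """Exact solution of x' = +1 for t < t_switch and x' = -1 for t >= t_switch."""
    t0 = ts[0]
    x_switch = x0 + (t_switch - t0)        # value reached when the slope flips
    rising = x0 + (ts - t0)                # slope +1 before the switch
    falling = x_switch - (ts - t_switch)   # slope -1 from the switch onwards
    return np.where(ts < t_switch, rising, falling)

num_points = 10
idx_half = num_points // 2
ts = np.linspace(0.0, 1.0, num=num_points)
xs = exact_solution(ts, x0=1.0, t_switch=ts[idx_half])
print(int(np.argmax(xs)))  # prints 5: the maximum sits at the switch index N/2
```

A solver that attaches each vector-field evaluation to the correct time stamp should reproduce this kink at grid index `N/2` (up to discretisation error).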

and got the following solution (you can disregard the orange lines):

[Figure: nonsmooth_probdiffeq_error_actual]

While it was clear that integrating the non-differentiable vector field would lead to integration errors, I wondered at which time step the solution is first updated with the new vector-field observation (this would presumably also matter for smooth, time-varying vector fields). I would have expected the jump in the gradient to occur at $t = t_{N/2}$; however, the gradient only changes at $t = t_{N/2+1}$. Here is what I would have expected to see (generated with my own vanilla, non-optimized PN solver implementation):

[Figure: nonsmooth_pnocp_error_expected]

Looking into the code, I found that switching the two lines here (and passing `t=t`, i.e. the updated time, as is already done for the correction) would solve the issue. What do you think?
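The off-by-one effect of the time stamp can be reproduced without any solver. The toy sketch below is plain NumPy; `vector_field` and `first_switch_index` are hypothetical helpers, not probdiffeq code. It assumes that each step from $t_n$ to $t_{n+1}$ attaches one vector-field evaluation to grid point $n+1$, and only varies which time that evaluation uses. Evaluating at the old time $t_n$ delays the sign flip by one grid point to $t_{N/2+1}$; evaluating at the updated time $t_{n+1}$ places it at $t_{N/2}$.

```python
import numpy as np

def vector_field(t, t_switch):
    # Piecewise-constant field from the issue: +1 before t_switch, -1 from t_switch on.
    return np.where(t < t_switch, 1.0, -1.0)

def first_switch_index(ts, t_switch, use_updated_time):
    """Grid index at which the attached slope first becomes negative (-1 if never).

    Step n -> n+1 attaches one slope to grid point n+1. That slope is
    f(t_n) if use_updated_time is False (old time stamp) and
    f(t_{n+1}) if use_updated_time is True (updated time stamp).
    """
    t_eval = ts[1:] if use_updated_time else ts[:-1]
    negative = vector_field(t_eval, t_switch) < 0  # entry k belongs to grid point k + 1
    if not negative.any():
        return -1
    return int(np.argmax(negative)) + 1

num_points = 10
idx_half = num_points // 2
ts = np.linspace(0.0, 1.0, num=num_points)
t_switch = ts[idx_half]

print(first_switch_index(ts, t_switch, use_updated_time=False))  # prints 6 (= N/2 + 1)
print(first_switch_index(ts, t_switch, use_updated_time=True))   # prints 5 (= N/2)
```

The first line matches the reported behaviour, the second the expected one; the toy says nothing about probdiffeq internals beyond which time the vector field is evaluated at.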

Here is the code to reproduce the figure:

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.15.2
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---

# +
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from jax.config import config

from probdiffeq import ivpsolve, adaptive, timestep
from probdiffeq.impl import impl
from probdiffeq.util.doc_util import notebook
from probdiffeq.solvers import uncalibrated, calibrated, solution, markov
from probdiffeq.taylor import autodiff
from probdiffeq.solvers.strategies import filters, smoothers, fixedpoint
from probdiffeq.solvers.strategies.components import corrections, priors

# +
plt.rcParams.update(notebook.plot_config())

if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
# -

ode_dim = 1
impl.select("isotropic", ode_shape=(ode_dim,))

# +
f, u0, (t0, t1), f_args = ivps.affine_independent(a=0.0, b=1.0)
f_2, u0_2, (t0_2, t1_2), f_args_2 = ivps.affine_independent(a=0.0, b=-1.0)

u0 = jnp.atleast_1d(u0)
u0_2 = jnp.atleast_1d(u0_2)

@jax.jit
def vf(ys, t):
    return f(ys, *f_args)

@jax.jit
def vf_2(ys, t):
    return f_2(ys, *f_args_2)

# -

num_points = 10
idx_half = num_points // 2
ts = jnp.linspace(t0, t1, endpoint=True, num=num_points)
ts_1 = ts[0:idx_half+1]
ts_2 = ts[idx_half:]

# combine vector fields with piecewise function
@jax.jit
def vf_combined(*ys, t):
    return jnp.where(t < ts_2[0], vf(*ys, t), vf_2(*ys, t))

# +
num_derivatives = 1
output_scale_guess = 1.0
ibm = priors.ibm_adaptive(num_derivatives=num_derivatives)
# ibm = priors.ibm_discretised(ts, num_derivatives=num_derivatives)
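use_filter = True
use_ts1 = False
use_calibrated = False
# Linearisation used in the correction step: first-order (TS1) or zeroth-order (TS0) Taylor series.
if use_ts1:
    correction = corrections.ts1()
else:
    correction = corrections.ts0()

if use_filter:
    strategy = filters.filter_adaptive(ibm, correction)
else:
    strategy = smoothers.smoother_adaptive(ibm, correction)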

if use_calibrated:
    solver = calibrated.mle(strategy)
else:
    solver = uncalibrated.solver(strategy)

# -

def solve(vf, u0, ts):
    tcoeffs = autodiff.taylor_mode_scan(lambda y: vf(y, t=t0), (u0,), num=num_derivatives)
    init = solver.initial_condition(tcoeffs, output_scale=output_scale_guess)
    sol = ivpsolve.solve_fixed_grid(
        vf, init, grid=ts, solver=solver
    )

    marginals = solution.calibrate(sol.marginals, output_scale=sol.output_scale)

    if not use_filter:
        posterior = solution.calibrate(sol.posterior, output_scale=sol.output_scale)
        posterior = markov.select_terminal(posterior)

        key = jax.random.PRNGKey(seed=1)
        (qoi, samples), _init = markov.sample(key, posterior, shape=(2,), reverse=True)
    else:
        posterior, qoi, samples = None, None, None

    return sol, marginals, posterior, qoi, samples

def plot(sol, marginals, posterior, qoi, samples):

    fig, axes_all = plt.subplots(
        nrows=2,
        ncols=num_derivatives+1,
        sharex=True,
        tight_layout=True,
        figsize=(8, 8),
    )

    if not isinstance(sol, list):
        sol = [sol]
        marginals = [marginals]
        posterior = [posterior]
        qoi = [qoi]
        samples = [samples]

    for sol, marginals, posterior, qoi, samples in zip(sol, marginals, posterior, qoi, samples):
        for i, axes_cols in enumerate(axes_all.T):
            ms = marginals.mean[:, i, :]
            ls = marginals.cholesky[:, i, :]
            stds = jnp.sqrt(jnp.einsum("jn,jn->j", ls, ls))

            if i == 1:
                axes_cols[0].set_title(f"{i}st deriv.")
            elif i == 2:
                axes_cols[0].set_title(f"{i}nd deriv.")
            elif i == 3:
                axes_cols[0].set_title(f"{i}rd deriv.")
            else:
                axes_cols[0].set_title(f"{i}th deriv.")

            axes_cols[0].plot(sol.t, ms, marker="None")

            if samples is not None:
                samps = samples[..., i, :]
                for s in samps:
                    axes_cols[0].plot(
                        sol.t[:-1], s[..., 0], color="C0", linewidth=0.35, marker="None"
                    )
                    axes_cols[0].plot(
                        sol.t[:-1], s[..., 1], color="C1", linewidth=0.35, marker="None"
                    )

            for m in ms.T:
                axes_cols[0].fill_between(sol.t, m - 1.96 * stds, m + 1.96 * stds, alpha=0.3)

            axes_cols[0].vlines(ts_2[0], ymin=axes_cols[0].get_ylim()[0], ymax=axes_cols[0].get_ylim()[1], color="black")
            axes_cols[0].vlines(ts, ymin=axes_cols[0].get_ylim()[0], ymax=axes_cols[0].get_ylim()[1], color="black", alpha=0.1)
            axes_cols[1].semilogy(sol.t, stds, marker="None")

    plt.show()
    return fig

# +
sol, marginals, posterior, qoi, samples = solve(vf_combined, u0, ts)

u0_2 = sol.u[idx_half,:]
sol_2, marginals_2, posterior_2, qoi_2, samples_2 = solve(vf_2, u0_2, ts_2)
fig = plot([sol,sol_2], [marginals,marginals_2], [posterior,posterior_2], [qoi,qoi_2], [samples,samples_2])
# plot([sol,sol_2,sol_f], [marginals,marginals_2,marginals_f], [posterior,posterior_2,posterior_f], [qoi,qoi_2,qoi_f], [samples,samples_2,samples_f])
fig.savefig("figures/nonsmooth_smoothing_error.png", bbox_inches="tight", dpi=300)
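To read the switch point off the solver output instead of the plot, a small helper can locate the first grid point at which the posterior mean of the first derivative turns negative. `first_negative_index` is a hypothetical helper, not part of probdiffeq. For the script above it assumes the first-derivative mean is `marginals.mean[:, 1, 0]` (derivative index 1, ODE dimension 0), matching the indexing used in `plot`; following the expectation stated in this issue, a correctly time-stamped solver should report `idx_half` rather than `idx_half + 1`.

```python
import numpy as np

def first_negative_index(slope_means):
    """Index of the first negative entry of a 1-D array, or -1 if there is none."""
    negative = np.asarray(slope_means) < 0
    return int(np.argmax(negative)) if negative.any() else -1

# Hypothetical usage with the arrays from the reproduction script:
#   first_negative_index(marginals.mean[:, 1, 0])
print(first_negative_index([1.0, 1.0, 1.0, -1.0, -1.0]))  # prints 3
print(first_negative_index([1.0, 1.0]))                    # prints -1
```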

Thanks, your insight would be much appreciated!

pnkraemer commented 10 months ago

> Looking into the code, I found that switching the two lines here (and using t=t with the updated time as for the correction time) would solve the issue. What do you think?

Sounds like you found an error in the source. Good catch!

How do you feel about a quick pull request? Your proposed solution sounds good :+1:

lahramon commented 10 months ago

Okay, I'll try!

pnkraemer commented 10 months ago

I released a new version. Your changes should now be available via `pip install --upgrade probdiffeq`.

lahramon commented 10 months ago

Thanks a lot for your quick action!