probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Example of EKF with an input? #349

Open cbrummitt opened 5 months ago

cbrummitt commented 5 months ago

Does anyone have any tips on how to modify the example of the extended Kalman filter/smoother to have an input, such as $u(t) = 0.1 \cos(t)$?

I see that the type annotation of the dynamics_function argument of ParamsNLGSSM allows for it to map the state and input to the state: f(x, u) = x_next. Similarly, the emission_function argument can map the state and input, (x, u), to the measurement.

We should be able to replicate the current version of that example with the trivial input $u(t) \equiv 0$. To that end, I tried combining the random keys rngs with a pre-computed array of inputs u = jnp.zeros((num_steps, 2)) and passing both into the xs argument of lax.scan, but I'm struggling to get _step to work. Is passing the input u alongside rngs in the xs argument of lax.scan the right approach?

cbrummitt commented 5 months ago

Based on https://github.com/google/jax/issues/763 I see I can pass a tuple to xs, like so:

_, (states, observations) = lax.scan(
    f=_step,
    init=params.initial_state,
    xs=(rngs, u)
)

Then the _step function just needs to unpack its second argument x as rng, u = x.
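As a minimal sketch of that pattern, the toy example below scans over a tuple of per-step keys and inputs. The function name `simulate`, the `noise_scale` parameter, and the additive toy dynamics are assumptions made for illustration; they are not part of the pendulum example or of dynamax.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import lax

def simulate(noise_scale, num_steps=5):
    rngs = jr.split(jr.PRNGKey(0), num_steps)   # one PRNG key per time step
    u = jnp.ones((num_steps, 2))                # hypothetical inputs, one row per step

    def _step(carry, x):
        rng, u_t = x                            # unpack this step's (key, input) pair
        noise = noise_scale * jr.normal(rng, carry.shape)
        next_state = carry + u_t + noise        # toy dynamics, not the pendulum
        return next_state, next_state

    # scan slices every array in the xs tuple along its leading axis
    return lax.scan(_step, jnp.zeros(2), (rngs, u))

final_state, states = simulate(noise_scale=0.0)
print(states.shape)
```

Every array in the tuple needs the same leading dimension (here `num_steps`), since `lax.scan` hands `_step` one slice of each per iteration.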

To learn how to add exogenous, time-dependent inputs, I'm considering a model that has torque $u(t)$ applied: $$\frac{d^2 \alpha}{dt^2} = - g \sin(\alpha) + w(t) + u(t).$$
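For reference, here is a forward-Euler discretization of this equation, sketched under the assumption that the state is $x = (\alpha, \omega)$ with $\omega = d\alpha/dt$, that $u_k = u(k \Delta t)$, and that the noise $w(t)$ is absorbed into a discrete-time noise term $q_k$ (notation introduced here, not taken from the tutorial):

$$\alpha_{k+1} = \alpha_k + \omega_k \, \Delta t, \qquad \omega_{k+1} = \omega_k + \left(-g \sin(\alpha_k) + u_k\right) \Delta t + q_k.$$

Under this discretization the torque enters only the velocity update and is scaled by $\Delta t$. If a dynamics function adds its input to the velocity unscaled, the input array would need to hold $u_k \, \Delta t$ rather than $u_k$ to match the equation above.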

I added u to the dynamics function:

    dynamics_function: Callable = lambda x, u: jnp.array(
        [x[0] + x[1] * dt + u[0],
         x[1] - g * jnp.sin(x[0]) * dt + u[1]
        ]
    )

I set $u(t) = 0.1 \sin(2 \pi t)$, as if the pendulum were on a boat. Below is a stand-alone script that incorporates the above to simulate the true angle and the observations:

%matplotlib inline
import matplotlib.pyplot as plt

import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jaxtyping import Float, Array
from typing import Callable, NamedTuple

from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother

dt = 0.0125
g = 9.8
q_c = 1
r = 0.3

class PendulumParamsWithInput(NamedTuple):
    initial_state: Float[Array, "state_dim"] = jnp.array(
        [jnp.pi / 2, 0])
    dynamics_function: Callable = lambda x, u: jnp.array(
        [
            x[0] + x[1] * dt + u[0],
            x[1] - g * jnp.sin(x[0]) * dt + u[1]
        ]
    )
    dynamics_covariance: Float[Array, "state_dim state_dim"] = jnp.array(
        [
            [q_c * dt**3 / 3, q_c * dt**2 / 2],
            [q_c * dt**2 / 2, q_c * dt]
        ]
    )
    # the emission is also perturbed by u[0]:
    emission_function: Callable = lambda x, u: jnp.array([jnp.sin(x[0])]) + u[0]
    emission_covariance: Float[Array, "emission_dim emission_dim"] = jnp.eye(1) * (r**2)

    # Torque as a function of the time-step index k, e.g., u(t) = 0.1 sin(2πt) with t = k * dt
    torque: Callable = lambda time_step: 0.1 * jnp.sin(2 * jnp.pi * time_step * dt)

# Pendulum simulation (Särkkä Example 3.7)
def simulate_pendulum_with_input(
    params, key=0, num_steps=400
):
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    # Unpack parameters
    M, N = params.initial_state.shape[0], params.emission_covariance.shape[0]
    f, h = params.dynamics_function, params.emission_function
    Q, R = params.dynamics_covariance, params.emission_covariance
    u = jnp.hstack(
        [
            jnp.zeros((num_steps, 1)),
            params.torque(jnp.arange(num_steps)).reshape(-1, 1),
        ]
    )

    def _step(carry, rng_u):
        state = carry
        rng, u = rng_u
        rng1, rng2 = jr.split(rng, 2)

        next_state = f(state, u) + jr.multivariate_normal(rng1, jnp.zeros(M), Q)
        obs = h(next_state, u) + jr.multivariate_normal(rng2, jnp.zeros(N), R)
        return next_state, (next_state, obs)

    rngs = jr.split(key, num_steps)
    _, (states, observations) = lax.scan(_step, params.initial_state, (rngs, u))
    return states, observations

params = PendulumParamsWithInput()
num_steps = 400
states, observations = simulate_pendulum_with_input(params, num_steps=num_steps)

time_grid = jnp.arange(num_steps) * dt
u = params.torque(jnp.arange(num_steps))
plt.plot(time_grid, states[:, 0], marker='o', markersize=2, label=r'angle $\alpha(t)$')
plt.plot(time_grid, u, label='torque $u(t)$')
plt.scatter(time_grid, observations[:, 0], facecolors='none', edgecolors='k', label='observations')
plt.xlabel('time $t$')
plt.legend()

I'm not sure why the measurements diverge from the true angle:

[image: plot of the angle, the torque, and the observations against time]
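One possible explanation, offered as an assumption rather than a confirmed diagnosis: the emission function returns `jnp.sin(x[0]) + u[0]`, and the first input column is zero in this script, so the observations follow $\sin(\alpha)$ while the plotted state is $\alpha$ itself. The two agree only for small angles. The plain-Python sketch below uses hypothetical angle values (not output from the simulation) to show how the gap grows; `emission_gap` is a name made up for this illustration.

```python
import math

def emission_gap(alpha):
    """Difference between an angle and its noise-free observation sin(alpha)."""
    return alpha - math.sin(alpha)

# Hypothetical angles in radians; the later ones are far from the small-angle regime.
angles = [0.0, 0.1, math.pi / 2, math.pi, 2.5 * math.pi]

for alpha in angles:
    observed = math.sin(alpha)
    print(f"alpha={alpha:6.3f}  sin(alpha)={observed:6.3f}  gap={emission_gap(alpha):6.3f}")
```

Under the same assumption, plotting `jnp.sin(states[:, 0])` next to the observations would be the like-for-like comparison.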

Trying the extended_kalman_smoother on this data

ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)

ekf_posterior = extended_kalman_smoother(ekf_params, observations)

is giving an error I haven't yet figured out:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[22], line 10
      1 ekf_params = ParamsNLGSSM(
      2     initial_mean=params.initial_state,
      3     initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
   (...)
      7     emission_covariance=params.emission_covariance,
      8 )
---> 10 ekf_posterior = extended_kalman_smoother(ekf_params, observations)

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:206, in extended_kalman_smoother(params, emissions, filtered_posterior, inputs)
    204 # Get filtered posterior
    205 if filtered_posterior is None:
--> 206     filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs)
    207 ll = filtered_posterior.marginal_loglik
    208 filtered_means = filtered_posterior.filtered_means

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:153, in extended_kalman_filter(params, emissions, num_iter, inputs, output_fields)
    151 # Run the extended Kalman filter
    152 carry = (0.0, params.initial_mean, params.initial_covariance)
--> 153 (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
    154 outputs = {"marginal_loglik": ll, **outputs}
    155 posterior_filtered = PosteriorGSSMFiltered(
    156     **outputs,
    157 )

    [... skipping hidden 9 frame]

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:129, in extended_kalman_filter.<locals>._step(carry, t)
    126 y = emissions[t]
    128 # Update the log likelihood
--> 129 H_x = H(pred_mean, u)
    130 ll += MVN(h(pred_mean, u), H_x @ pred_cov @ H_x.T + R).log_prob(jnp.atleast_1d(y))
    132 # Condition on this emission

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:16, in <lambda>(x, y)
     14 # Helper functions
     15 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
---> 16 _process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
     17 _process_input = lambda x, y: jnp.zeros((y,1)) if x is None else x
     20 def _predict(m, P, f, F, Q, u):

    [... skipping hidden 5 frame]

File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
    188 gen = gen_static_args = out_store = None
    190 try:
--> 191   ans = self.f(*args, **dict(self.params, **kwargs))
    192 except:
    193   # Some transformations yield from inside context managers, so we have to
    194   # interrupt them before reraising the exception. Otherwise they will only
    195   # get garbage-collected at some later time, running their cleanup tasks
    196   # only after this exception is handled, which can corrupt the global
    197   # state.
    198   while stack:

TypeError: PendulumParamsWithInput.<lambda>() missing 1 required positional argument: 'u'
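The last dynamax frame in the traceback shows `_process_fn` wrapping the model function as `lambda x, y: f(x)` whenever `u is None`, that is, when no inputs are passed. A two-argument `dynamics_function` or `emission_function` is then called with a single argument, which matches the TypeError. The sketch below copies that helper from the traceback and reproduces the behavior in plain Python; `f` and the numeric values are hypothetical stand-ins.

```python
# Helper copied from the traceback (inference_ekf.py, line 16).
_process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f

# Hypothetical stand-in for a two-argument model function f(x, u).
f = lambda x, u: x + u

# Without inputs, the wrapper drops the second argument, so f(x) fails.
wrapped_without_inputs = _process_fn(f, None)
try:
    wrapped_without_inputs(1.0, 0.0)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)

# With inputs, f is returned unchanged and receives both arguments.
wrapped_with_inputs = _process_fn(f, [0.5])
result = wrapped_with_inputs(1.0, 0.5)
print(result)
```

Given the `inputs` parameter visible in the signatures in the traceback, calling `extended_kalman_smoother(ekf_params, observations, inputs=u)` with the same `(num_steps, 2)` array that is built inside `simulate_pendulum_with_input` may avoid this code path. That call is untested here.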

Would adding an example with an input $u(t)$ to the tutorial "Tracking a 1d pendulum using Extended / Unscented Kalman filter/ smoother" be of interest?

matthewrhysjones commented 3 weeks ago

Hi! Did you manage to make it any further on this problem?

cbrummitt commented 2 weeks ago

Hi! No, I haven't.