Open cbrummitt opened 5 months ago
Based on https://github.com/google/jax/issues/763, I see I can pass a tuple to `xs`, like so:
# `xs` is a tuple, so each scan step receives the matching (rng, u) pair as
# the second argument of `_step`.
_, (states, observations) = lax.scan(
f=_step,
init=params.initial_state,
xs=(rngs, u)
)
Then the `_step` function just needs to unpack its second argument `x` as `rng, u = x`.
To learn how to add exogenous, time-dependent inputs, I'm considering a model that has torque $u(t)$ applied: $$\frac{d^2 \alpha}{dt^2} = - g \sin(\alpha) + w(t) + u(t).$$
I added `u` to the dynamics function:
# One discretized pendulum step with the input added: u[0] is added to the
# angle update (x[0]) and u[1] to the angular-velocity update (x[1]).
dynamics_function: Callable = lambda x, u: jnp.array(
[x[0] + x[1] * dt + u[0],
x[1] - g * jnp.sin(x[0]) * dt + u[1]
]
)
I set $u(t) = 0.1 \sin(2 \pi t)$, as if the pendulum were on a boat. Below is a stand-alone script that incorporates the above into generating the true value of the angle:
%matplotlib inline
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jaxtyping import Float, Array
from typing import Callable, NamedTuple
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM, UKFHyperParams
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother, unscented_kalman_smoother
# Step length used by the discretized dynamics, the process-noise covariance,
# and the plotting time grid.
dt = 0.0125
# Coefficient of the -g * sin(angle) term in the angular-velocity update.
g = 9.8
# Scale factor multiplying every entry of the process-noise covariance.
q_c = 1
# Emission-noise standard deviation; the emission covariance is r**2.
r = 0.3
class PendulumParamsWithInput(NamedTuple):
    """Default parameters for a noisy pendulum driven by an exogenous input.

    The state is ``x = [angle, angular velocity]``. The input ``u`` is indexed
    as ``u[0]`` and ``u[1]``: ``u[0]`` is added to the angle update and to the
    emission, and ``u[1]`` is added to the angular-velocity update.
    """

    # Starting state: angle pi/2 and zero angular velocity.
    initial_state: Float[Array, "state_dim"] = jnp.array([jnp.pi / 2, 0])

    # One explicit Euler step of length dt, with the input added on top.
    # NOTE(review): u is added without a dt factor, so it acts as a per-step
    # increment rather than a continuous-time torque integrated over dt --
    # confirm this is the intended scaling.
    dynamics_function: Callable = lambda x, u: jnp.array(
        [
            x[0] + x[1] * dt + u[0],
            x[1] - g * jnp.sin(x[0]) * dt + u[1],
        ]
    )

    # Process-noise covariance, scaled by q_c.
    dynamics_covariance: Float[Array, "state_dim state_dim"] = jnp.array(
        [
            [q_c * dt**3 / 3, q_c * dt**2 / 2],
            [q_c * dt**2 / 2, q_c * dt],
        ]
    )

    # The emission is sin(angle), also perturbed by u[0].
    emission_function: Callable = lambda x, u: jnp.array([jnp.sin(x[0])]) + u[0]

    # Emission-noise covariance: a square (emission_dim, emission_dim) matrix,
    # here 1x1 with variance r**2.
    emission_covariance: Float[Array, "emission_dim emission_dim"] = jnp.eye(1) * (r**2)

    # Torque as a function of the integer time step:
    # u(t) = 0.1 * sin(2*pi*t) with t = time_step * dt.
    torque: Callable = lambda time_step: 0.1 * jnp.sin(2 * jnp.pi * time_step * dt)
# Pendulum simulation (Särkkä Example 3.7)
def simulate_pendulum_with_input(
    params, key=0, num_steps=400
):
    """Sample a state/observation trajectory of the input-driven pendulum.

    Args:
        params: Object exposing ``initial_state``, ``dynamics_function``,
            ``dynamics_covariance``, ``emission_function``,
            ``emission_covariance`` and ``torque``.
        key: Integer seed (converted with ``jr.PRNGKey``) or a JAX PRNG key.
        num_steps: Number of simulation steps.

    Returns:
        A ``(states, observations)`` pair, each stacked along the leading
        (time) axis by ``lax.scan``.
    """
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    transition = params.dynamics_function
    emit = params.emission_function
    process_cov = params.dynamics_covariance
    obs_cov = params.emission_covariance
    state_dim = params.initial_state.shape[0]
    obs_dim = obs_cov.shape[0]

    # Per-step inputs: column 0 is all zeros, column 1 is the torque
    # evaluated at each integer step.
    torque_column = params.torque(jnp.arange(num_steps)).reshape(-1, 1)
    inputs = jnp.concatenate(
        [jnp.zeros((num_steps, 1)), torque_column], axis=1
    )

    def _advance(prev_state, key_and_input):
        step_key, u_t = key_and_input
        # Independent sub-keys: the first drives process noise, the second
        # drives observation noise.
        process_key, obs_key = jr.split(step_key, 2)
        process_noise = jr.multivariate_normal(
            process_key, jnp.zeros(state_dim), process_cov
        )
        new_state = transition(prev_state, u_t) + process_noise
        obs_noise = jr.multivariate_normal(
            obs_key, jnp.zeros(obs_dim), obs_cov
        )
        observation = emit(new_state, u_t) + obs_noise
        return new_state, (new_state, observation)

    step_keys = jr.split(key, num_steps)
    _, (states, observations) = lax.scan(
        _advance, params.initial_state, (step_keys, inputs)
    )
    return states, observations
# Simulate the driven pendulum with the default parameters.
params = PendulumParamsWithInput()
num_steps = 400
states, observations = simulate_pendulum_with_input(params, num_steps=num_steps)

# Physical time of each step, and the torque evaluated at each step.
time_grid = jnp.arange(num_steps) * dt
u = params.torque(jnp.arange(num_steps))

# Draw everything on the current axes.
# NOTE(review): the observations are outputs of `emission_function`
# (sin(angle) + u[0], plus noise), not the angle itself, so they need not
# follow states[:, 0] once the angle is no longer small.
ax = plt.gca()
ax.plot(
    time_grid,
    states[:, 0],
    marker='o',
    markersize=2,
    label=r'angle $\alpha(t)$',
)
ax.plot(time_grid, u, label='torque $u(t)$')
ax.scatter(
    time_grid,
    observations[:, 0],
    facecolors='none',
    edgecolors='k',
    label='observations',
)
ax.set_xlabel('time $t$')
ax.legend()
I'm not sure why the measurements diverge from the true angle:
Trying the `extended_kalman_smoother` on this data
# The dynamics and emission functions take (x, u), so the smoother must be
# given the per-time-step inputs. When `inputs` is omitted, dynamax calls the
# functions with the state only, which raises
# "missing 1 required positional argument: 'u'".
# Build the same two-column input used by the simulator: column 0 is zeros,
# column 1 is the torque at each step.
# NOTE(review): confirm that dynamax's time indexing of `inputs` matches the
# simulator, which applies u[t] when stepping into state t.
num_timesteps = observations.shape[0]
inputs = jnp.hstack(
    [
        jnp.zeros((num_timesteps, 1)),
        params.torque(jnp.arange(num_timesteps)).reshape(-1, 1),
    ]
)
ekf_params = ParamsNLGSSM(
    initial_mean=params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=params.dynamics_function,
    dynamics_covariance=params.dynamics_covariance,
    emission_function=params.emission_function,
    emission_covariance=params.emission_covariance,
)
ekf_posterior = extended_kalman_smoother(ekf_params, observations, inputs=inputs)
is giving an error I haven't yet figured out:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[22], line 10
1 ekf_params = ParamsNLGSSM(
2 initial_mean=params.initial_state,
3 initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
(...)
7 emission_covariance=params.emission_covariance,
8 )
---> 10 ekf_posterior = extended_kalman_smoother(ekf_params, observations)
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:206, in extended_kalman_smoother(params, emissions, filtered_posterior, inputs)
204 # Get filtered posterior
205 if filtered_posterior is None:
--> 206 filtered_posterior = extended_kalman_filter(params, emissions, inputs=inputs)
207 ll = filtered_posterior.marginal_loglik
208 filtered_means = filtered_posterior.filtered_means
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:153, in extended_kalman_filter(params, emissions, num_iter, inputs, output_fields)
151 # Run the extended Kalman filter
152 carry = (0.0, params.initial_mean, params.initial_covariance)
--> 153 (ll, *_), outputs = lax.scan(_step, carry, jnp.arange(num_timesteps))
154 outputs = {"marginal_loglik": ll, **outputs}
155 posterior_filtered = PosteriorGSSMFiltered(
156 **outputs,
157 )
[... skipping hidden 9 frame]
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:129, in extended_kalman_filter.<locals>._step(carry, t)
126 y = emissions[t]
128 # Update the log likelihood
--> 129 H_x = H(pred_mean, u)
130 ll += MVN(h(pred_mean, u), H_x @ pred_cov @ H_x.T + R).log_prob(jnp.atleast_1d(y))
132 # Condition on this emission
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/dynamax/nonlinear_gaussian_ssm/inference_ekf.py:16, in <lambda>(x, y)
14 # Helper functions
15 _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x
---> 16 _process_fn = lambda f, u: (lambda x, y: f(x)) if u is None else f
17 _process_input = lambda x, y: jnp.zeros((y,1)) if x is None else x
20 def _predict(m, P, f, F, Q, u):
[... skipping hidden 5 frame]
File ~/.pyenv/versions/3.10/envs/venv.data_assimilation/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
188 gen = gen_static_args = out_store = None
190 try:
--> 191 ans = self.f(*args, **dict(self.params, **kwargs))
192 except:
193 # Some transformations yield from inside context managers, so we have to
194 # interrupt them before reraising the exception. Otherwise they will only
195 # get garbage-collected at some later time, running their cleanup tasks
196 # only after this exception is handled, which can corrupt the global
197 # state.
198 while stack:
TypeError: PendulumParamsWithInput.<lambda>() missing 1 required positional argument: 'u'
Does adding an example with $u(t)$ to the tutorial "Tracking a 1d pendulum using Extended / Unscented Kalman filter/smoother" seem of potential interest?
Hi! Did you manage to make it any further on this problem?
Hi! No, I haven't.
Does anyone have any tips on how to modify the example of the extended Kalman filter/smoother to have an input, such as $u(t) = 0.1 \cos(t)$?
I see that the type annotation of the `dynamics_function` argument of `ParamsNLGSSM` allows for it to map the state and input to the state: f(x, u) = x_next. Similarly, the `emission_function` argument can map the state and input, (x, u), to the measurement.

We should be able to replicate the current version of that example with the trivial input of $u(t) \equiv 0$. To that end, I tried to combine the random seeds `rngs` with a pre-computed array of inputs `u = jnp.zeros((num_steps, 2))` and pass that into the `xs` argument of `lax.scan`, but I'm struggling to get the `_step` to work. Is this at all on the right track, to pass the input `u` by attaching it to the `rngs` in the `xs` argument of `lax.scan`?