probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Added EKF sampler #313

Closed (calebweinreb closed 1 year ago)

calebweinreb commented 1 year ago

This PR adds extended_kalman_posterior_sample, which has the same signature as lgssm_posterior_sample. It behaves as expected for the pendulum example in the docs:

from dynamax.nonlinear_gaussian_ssm import extended_kalman_posterior_sample

ekf_params = ParamsNLGSSM(
    initial_mean=pendulum_params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=pendulum_params.dynamics_function,
    dynamics_covariance=pendulum_params.dynamics_covariance,
    emission_function=pendulum_params.emission_function,
    emission_covariance=pendulum_params.emission_covariance,
)

ekf_posterior = extended_kalman_smoother(ekf_params, obs)
sampled_states = extended_kalman_posterior_sample(jr.PRNGKey(0), ekf_params, obs)

m_ekf = sampled_states[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ekf, est_type="EKF (sampled)")
compute_and_print_rmse_comparison(states[:, 0], m_ekf, r, "EKF")

Screenshot 2023-05-13 at 2 25 22 PM

Increasing the process and observation noise 20-fold:

Screenshot 2023-05-13 at 6 07 12 PM

This PR also adds the following tests:

from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_sampler_linear,
    test_extended_kalman_sampler_nonlinear)

test_extended_kalman_sampler_linear()
test_extended_kalman_sampler_nonlinear()

In the process of writing these tests, I found that lgssm_filter and extended_kalman_filter returned asymmetric covariance matrices for the default examples in dynamax.nonlinear_gaussian_ssm.inference_test (see issue https://github.com/probml/dynamax/issues/317). As a result, the outputs of lgssm_posterior_sample and extended_kalman_posterior_sample were all NaN and could not be used for testing. This asymmetry issue is addressed in two separate PRs (https://github.com/probml/dynamax/pull/318 for lgssm_filter and https://github.com/probml/dynamax/pull/319 for extended_kalman_filter). Both of those PRs need to be merged before this one or the tests won't work. So the overall merge order would be:

  1. https://github.com/probml/dynamax/pull/318 (ensures LGSSM filtered covariance is symmetric)
  2. https://github.com/probml/dynamax/pull/319 (ensures EKF filtered covariance is symmetric)
  3. https://github.com/probml/dynamax/pull/313 (the current PR that adds an EKF sampler)
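
For readers unfamiliar with the asymmetry problem, here is a minimal sketch of one common remedy: replacing a covariance matrix with its symmetric part. It is written in plain NumPy, the matrix values are hypothetical, and the `symmetrize` helper is an illustrative name. This is not necessarily how PRs 318 and 319 implement the fix.

```python
import numpy as np

def symmetrize(P):
    """Return the symmetric part of a square matrix."""
    return 0.5 * (P + P.T)

# Hypothetical covariance with a tiny asymmetric perturbation,
# imitating round-off that can accumulate across filter updates.
P = np.array([[2.0, 0.3],
              [0.3, 1.0]])
P_drifted = P + np.array([[0.0, 1e-7],
                          [0.0, 0.0]])

P_fixed = symmetrize(P_drifted)

# Exact symmetry checks before and after the fix.
is_symmetric_before = np.array_equal(P_drifted, P_drifted.T)
is_symmetric_after = np.array_equal(P_fixed, P_fixed.T)

# The symmetrized matrix still factorizes as a valid covariance.
chol = np.linalg.cholesky(P_fixed)
```

Averaging a matrix with its transpose leaves an already symmetric matrix unchanged, so applying it after every filter update is harmless in the well-behaved case.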

slinderman commented 1 year ago

Thanks @calebweinreb! This looks good to me.

Is there a way to test this function? I guess we could check that the sample mean matches the EKF smoother mean?

calebweinreb commented 1 year ago

Here are two possible tests:

1. Match to LGSSM

The output of extended_kalman_posterior_sample should match that of lgssm_posterior_sample when the dynamics and emission functions are linear and both samplers use the same PRNG key. It does in the following example.

from jax import numpy as jnp
from jax import random as jr
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM
from dynamax.nonlinear_gaussian_ssm import extended_kalman_posterior_sample
from dynamax.linear_gaussian_ssm import lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference_test import build_lgssm_for_sampling

# simulate LGSSM
num_timesteps = 100
key = jr.PRNGKey(0)
sample_key, key = jr.split(key)
lgssm, lgssm_params = build_lgssm_for_sampling()
states, emissions = lgssm.sample(lgssm_params, key=sample_key, num_timesteps=num_timesteps)

# sample from LGSSM
lgssm_sampled_states = lgssm_posterior_sample(key, lgssm_params, emissions)

# sample from EKF
ekf_params = ParamsNLGSSM(
    initial_mean=lgssm_params.initial.mean,
    initial_covariance=lgssm_params.initial.cov,
    dynamics_function=(lambda x: lgssm_params.dynamics.weights @ x),
    dynamics_covariance=lgssm_params.dynamics.cov,
    emission_function=(lambda x: lgssm_params.emissions.weights @ x),
    emission_covariance=lgssm_params.emissions.cov)
ekf_sampled_states = extended_kalman_posterior_sample(key, ekf_params, emissions)

print(jnp.allclose(lgssm_sampled_states, ekf_sampled_states)) 
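
The reason the two samplers are expected to agree can be sketched as follows. An extended Kalman filter linearizes the dynamics and emission functions through their Jacobians, and the Jacobian of a linear map `x -> F @ x` is `F` at every point, so the linearization loses nothing. The sketch below is plain NumPy with a finite-difference Jacobian as a stand-in for automatic differentiation. The matrix `F` and the helper `numerical_jacobian` are hypothetical and not part of dynamax.

```python
import numpy as np

def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f at x (stand-in for autodiff)."""
    x = np.asarray(x, dtype=float)
    fx = np.asarray(f(x))
    J = np.zeros((fx.size, x.size))
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        J[:, i] = (f(x + step) - f(x - step)) / (2 * eps)
    return J

# Hypothetical 2-D dynamics matrix, for illustration only.
F = np.array([[1.0, 0.1],
              [0.0, 0.9]])
dynamics_function = lambda x: F @ x

# Linearize at two different points; both Jacobians should equal F.
J_a = numerical_jacobian(dynamics_function, np.array([0.3, -1.2]))
J_b = numerical_jacobian(dynamics_function, np.array([5.0, 2.0]))

linear_case_matches = (np.allclose(J_a, F, atol=1e-6)
                       and np.allclose(J_b, F, atol=1e-6))
```

With a nonlinear dynamics function the Jacobian depends on the linearization point, which is why the nonlinear case needs a separate, statistical test.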

2. Match to smoother

The site-wise mean and variance of the sampled states should match the marginal mean and variance from the smoother. Here's an example using the pendulum example from the docs.

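# Assumes `states`, `obs`, and PendulumParams from the pendulum example in the docs.
import jax
import matplotlib.pyplot as plt
from dynamax.nonlinear_gaussian_ssm import extended_kalman_smoother

pendulum_params = PendulumParams()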

ekf_params = ParamsNLGSSM(
    initial_mean=pendulum_params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=pendulum_params.dynamics_function,
    dynamics_covariance=pendulum_params.dynamics_covariance*10,
    emission_function=pendulum_params.emission_function,
    emission_covariance=pendulum_params.emission_covariance*10,
)

# smoothing
ekf_posterior = extended_kalman_smoother(ekf_params, obs)

# sampling: draw num_samples posterior trajectories, one per PRNG key
num_samples = 100000
sample_fun = jax.vmap(extended_kalman_posterior_sample, in_axes=(0, None, None))
samples = sample_fun(jr.split(jr.PRNGKey(0), num_samples), ekf_params, obs)

# compare by plotting
fig,axs = plt.subplots(1,2)
axs[0].plot(ekf_posterior.smoothed_means[:,0], label='smoothed mean')
axs[0].plot(jnp.mean(samples[:,:,0],axis=0), label='mean of samples')
axs[0].set_ylabel('mean')
axs[0].legend(loc='upper left')
axs[1].plot(ekf_posterior.smoothed_covariances[:,0,0], label='smoothed variance')
axs[1].plot(jnp.var(samples[:,:,0],axis=0), label='variance of samples')
axs[1].legend(loc='upper left')
axs[1].set_ylabel('variance')
fig.set_size_inches((15,3))

Here's the result:

Screenshot 2023-05-15 at 8 55 11 PM

The error seems reasonable given the sample size:

empirical_error = jnp.std(ekf_posterior.smoothed_means[:,0] - jnp.mean(samples[:,:,0],axis=0))
expected_error = jnp.sqrt(ekf_posterior.smoothed_covariances[:,0,0] / num_samples)
print(empirical_error, expected_error.mean())

yields

empirical_error = 0.000929
expected_error = 0.000768
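
The comparison above relies on the standard error of a Monte Carlo mean: with N independent samples of variance σ², the sample mean deviates from the true mean by about sqrt(σ² / N). A minimal NumPy sketch of that check, using hypothetical per-site means and variances rather than the pendulum posterior, might look like this:

```python
import numpy as np

rng = np.random.default_rng(0)
num_sites, num_samples = 200, 10_000

# Hypothetical per-site means and variances (illustration only).
true_means = rng.normal(size=num_sites)
true_vars = rng.uniform(0.5, 1.5, size=num_sites)

# Independent Gaussian draws, shape (num_samples, num_sites).
noise = rng.normal(size=(num_samples, num_sites))
samples = true_means + np.sqrt(true_vars) * noise

# Spread of the error in the sample means across sites ...
empirical_error = np.std(true_means - samples.mean(axis=0))
# ... versus the standard error predicted at each site.
expected_error = np.sqrt(true_vars / num_samples)

ratio = empirical_error / expected_error.mean()
```

A ratio close to 1 indicates that the sample means are as accurate as the sample size allows. The figures reported above (0.000929 versus 0.000768) give a ratio of roughly 1.2.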

murphyk commented 1 year ago

These are both great tests. Please add them to https://github.com/probml/dynamax/blob/main/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py. And change print(allclose) to assert(allclose) :)

calebweinreb commented 1 year ago

I added test functions similar to the ones proposed above. They use the examples generated by dynamax.nonlinear_gaussian_ssm.inference_test_utils, for consistency with the other tests in dynamax.nonlinear_gaussian_ssm.inference_test.py. The tests can be run as follows:

from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_sampler_linear,
    test_extended_kalman_sampler_nonlinear)

test_extended_kalman_sampler_linear()
test_extended_kalman_sampler_nonlinear()

In the process of writing these tests, I found that lgssm_filter and extended_kalman_filter returned asymmetric covariance matrices (see issue https://github.com/probml/dynamax/issues/317), and consequently the outputs of lgssm_posterior_sample and extended_kalman_posterior_sample were all NaN. This asymmetry issue is addressed in two separate PRs (https://github.com/probml/dynamax/pull/318 for lgssm_filter and https://github.com/probml/dynamax/pull/319 for extended_kalman_filter). Both of those PRs need to be merged before this one or the tests won't work. So the overall merge order would be:

  1. https://github.com/probml/dynamax/pull/318 (ensures LGSSM filtered covariance is symmetric)
  2. https://github.com/probml/dynamax/pull/319 (ensures EKF filtered covariance is symmetric)
  3. https://github.com/probml/dynamax/pull/313 (the current PR that adds an EKF sampler)