calebweinreb closed this 1 year ago.
Thanks @calebweinreb! This looks good to me.
Is there a way to test this function? I guess we could check that the sample mean matches the EKF smoother mean?
Here are two possible tests:
The output of `extended_kalman_posterior_sample` should match `lgssm_posterior_sample` when the dynamics function is linear. It does in the following example.
```python
from jax import numpy as jnp
from jax import random as jr
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM
from dynamax.nonlinear_gaussian_ssm import extended_kalman_posterior_sample
from dynamax.linear_gaussian_ssm import lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference_test import build_lgssm_for_sampling

# simulate LGSSM
num_timesteps = 100
key = jr.PRNGKey(0)
sample_key, key = jr.split(key)
lgssm, lgssm_params = build_lgssm_for_sampling()
states, emissions = lgssm.sample(lgssm_params, key=sample_key, num_timesteps=num_timesteps)

# sample from LGSSM
lgssm_sampled_states = lgssm_posterior_sample(key, lgssm_params, emissions)

# sample from EKF
ekf_params = ParamsNLGSSM(
    initial_mean=lgssm_params.initial.mean,
    initial_covariance=lgssm_params.initial.cov,
    dynamics_function=(lambda x: lgssm_params.dynamics.weights @ x),
    dynamics_covariance=lgssm_params.dynamics.cov,
    emission_function=(lambda x: lgssm_params.emissions.weights @ x),
    emission_covariance=lgssm_params.emissions.cov)
ekf_sampled_states = extended_kalman_posterior_sample(key, ekf_params, emissions)

print(jnp.allclose(lgssm_sampled_states, ekf_sampled_states))
```
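Why the two samplers can agree exactly here: the EKF linearizes the dynamics and emission functions through their Jacobians, and for a linear map x ↦ A x the Jacobian is A at every linearization point, so the EKF recursions reduce to the ordinary Kalman ones. (An exact match of the draws also assumes both samplers consume the shared `key` in the same way, which the example above suggests.) A quick JAX check of the Jacobian fact, with a made-up matrix `A` standing in for `lgssm_params.dynamics.weights`:

```python
import jax
import jax.numpy as jnp

# Made-up linear dynamics, standing in for lgssm_params.dynamics.weights
A = jnp.array([[0.9, 0.1], [0.0, 0.8]])
f = lambda x: A @ x

# The Jacobian of a linear map is the same matrix wherever it is evaluated,
# so an EKF linearization of f is exact.
for x in [jnp.zeros(2), jnp.array([1.0, -2.0]), jnp.array([10.0, 3.0])]:
    J = jax.jacfwd(f)(x)
    print(bool(jnp.allclose(J, A)))
```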
The per-timestep mean and variance of the sampled states should match the marginal means and variances from the smoother. Here's a test using the pendulum example from the docs.
```python
import jax
import jax.numpy as jnp
from jax import random as jr
import matplotlib.pyplot as plt
from dynamax.nonlinear_gaussian_ssm import (
    ParamsNLGSSM, extended_kalman_smoother, extended_kalman_posterior_sample)

# PendulumParams, `states`, and `obs` come from the pendulum example in the docs
pendulum_params = PendulumParams()
ekf_params = ParamsNLGSSM(
    initial_mean=pendulum_params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=pendulum_params.dynamics_function,
    dynamics_covariance=pendulum_params.dynamics_covariance * 10,
    emission_function=pendulum_params.emission_function,
    emission_covariance=pendulum_params.emission_covariance * 10,
)

# smoothing
ekf_posterior = extended_kalman_smoother(ekf_params, obs)

# sampling: vmap over keys, sharing params and observations
num_samples = 100000
sample_fun = jax.vmap(extended_kalman_posterior_sample, in_axes=(0, None, None))
samples = sample_fun(jr.split(jr.PRNGKey(0), num_samples), ekf_params, obs)

# compare by plotting (first state dimension)
fig, axs = plt.subplots(1, 2)
axs[0].plot(ekf_posterior.smoothed_means[:, 0], label='smoothed mean')
axs[0].plot(jnp.mean(samples[:, :, 0], axis=0), label='mean of samples')
axs[0].set_ylabel('mean')
axs[0].legend(loc='upper left')
axs[1].plot(ekf_posterior.smoothed_covariances[:, 0, 0], label='smoothed variance')
axs[1].plot(jnp.var(samples[:, :, 0], axis=0), label='variance of samples')
axs[1].legend(loc='upper left')
axs[1].set_ylabel('variance')
fig.set_size_inches((15, 3))
```
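Here's the result (figure: smoothed mean vs. mean of samples, left; smoothed variance vs. variance of samples, right).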
The error seems reasonable given the sample size:
```python
empirical_error = jnp.std(ekf_posterior.smoothed_means[:, 0] - jnp.mean(samples[:, :, 0], axis=0))
expected_error = jnp.sqrt(ekf_posterior.smoothed_covariances[:, 0, 0] / num_samples)
print(empirical_error, expected_error.mean())
```
yields `empirical_error = 0.000929` and `expected_error = 0.000768`.
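The "expected error" above is the standard error of a Monte Carlo mean, sqrt(Var/N) per timestep. A self-contained sanity check of the same comparison on synthetic Gaussian draws (made-up means and variances, not the pendulum posterior) is sketched below. Note that when variances differ across timesteps, the std of the errors estimates sqrt(mean(Var)/N), which by Jensen's inequality is at least mean(sqrt(Var/N)), so a ratio slightly above 1 is expected even with an exact sampler.

```python
import jax.numpy as jnp
from jax import random as jr

num_sites, num_samples = 200, 10_000
true_means = jnp.linspace(-1.0, 1.0, num_sites)        # made-up marginal means
true_vars = jnp.linspace(0.5, 1.5, num_sites) ** 2     # made-up marginal variances

# Exact samples from the marginals, shape (num_samples, num_sites)
noise = jr.normal(jr.PRNGKey(0), (num_samples, num_sites))
samples = true_means + jnp.sqrt(true_vars) * noise

# Same statistics as in the comment above
empirical_error = jnp.std(true_means - jnp.mean(samples, axis=0))
expected_error = jnp.sqrt(true_vars / num_samples).mean()
ratio = empirical_error / expected_error
print(float(empirical_error), float(expected_error), float(ratio))
```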
These are both great tests. Please add them to https://github.com/probml/dynamax/blob/main/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py, and change `print(allclose)` to `assert(allclose)` :)
I added test functions similar to the ones proposed above, but they use the examples generated by `dynamax.nonlinear_gaussian_ssm.inference_test_utils` for consistency with the other tests in `dynamax.nonlinear_gaussian_ssm.inference_test.py`. The tests can be run as follows:
```python
from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_sampler_linear,
    test_extended_kalman_sampler_nonlinear)

test_extended_kalman_sampler_linear()
test_extended_kalman_sampler_nonlinear()
```
In the process of writing these tests, I found that `lgssm_filter` and `extended_kalman_filter` returned asymmetric covariance matrices (see issue https://github.com/probml/dynamax/issues/317), and consequently the outputs of `lgssm_posterior_sample` and `extended_kalman_posterior_sample` were all NaN. This asymmetry issue is addressed in two separate PRs (https://github.com/probml/dynamax/pull/318 for `lgssm_filter` and https://github.com/probml/dynamax/pull/319 for `extended_kalman_filter`). Both of those PRs need to be merged before this one, or the tests won't work. So the overall merge order would be:
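Symmetrizing a covariance matrix after each update is one common guard against this kind of drift. The sketch below makes no claim about how #318 and #319 actually fix it; `symmetrize` is a hypothetical helper, and the asymmetric matrix is made up to mimic accumulated round-off.

```python
import jax.numpy as jnp

def symmetrize(P):
    """Average a matrix (or a batch of matrices) with its transpose."""
    return 0.5 * (P + jnp.swapaxes(P, -1, -2))

# Made-up covariance with a small asymmetric error in one off-diagonal entry
P = jnp.array([[2.0, 0.5], [0.5, 1.0]])
P_bad = P.at[0, 1].add(1e-3)

P_fixed = symmetrize(P_bad)
# Floating-point addition is commutative, so the result is exactly symmetric
print(bool(jnp.array_equal(P_bad, P_bad.T)), bool(jnp.array_equal(P_fixed, P_fixed.T)))
```

Because it uses `swapaxes` on the last two axes, the same helper works on a stacked array of per-timestep covariances.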
This PR adds `extended_kalman_posterior_sample`, which has the same signature as `lgssm_posterior_sample`. It behaves as expected for the pendulum example in the docs:
(figure: pendulum example)
Increasing the process and observation noise 20-fold:
(figure: pendulum example with 20-fold process and observation noise)
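A posterior sampler with this `(key, params, emissions)` signature is typically built on forward-filtering backward-sampling (FFBS): run a Kalman filter forward, draw the last state from its filtered marginal, then draw each earlier state from its Gaussian conditional given the next sampled state. The sketch below is written from scratch for a linear Gaussian model with hypothetical helper names (`kalman_filter`, `ffbs_sample`) and made-up parameters; it is not dynamax's implementation. An extended Kalman variant would presumably replace `F` and `H` with Jacobians of the dynamics and emission functions evaluated at the filtered means.

```python
import jax
import jax.numpy as jnp
from jax import lax, random as jr

def kalman_filter(m0, P0, F, Q, H, R, ys):
    """Filtered means and covariances for z_t = F z_{t-1} + q, y_t = H z_t + r."""
    def step(carry, y):
        m_pred, P_pred = carry
        S = H @ P_pred @ H.T + R
        K = jnp.linalg.solve(S, H @ P_pred).T      # K = P_pred H^T S^{-1}
        m = m_pred + K @ (y - H @ m_pred)
        P = P_pred - K @ S @ K.T
        P = 0.5 * (P + P.T)                         # keep covariance symmetric
        return (F @ m, F @ P @ F.T + Q), (m, P)     # carry = next prediction
    _, (ms, Ps) = lax.scan(step, (m0, P0), ys)
    return ms, Ps

def ffbs_sample(key, m0, P0, F, Q, H, R, ys):
    """Draw one joint sample z_{1:T} ~ p(z_{1:T} | y_{1:T})."""
    ms, Ps = kalman_filter(m0, P0, F, Q, H, R, ys)
    keys = jr.split(key, ys.shape[0])
    z_last = jr.multivariate_normal(keys[-1], ms[-1], Ps[-1])

    def back(z_next, args):
        key_t, m, P = args
        P_pred = F @ P @ F.T + Q
        G = jnp.linalg.solve(P_pred, F @ P).T      # G = P F^T P_pred^{-1}
        mean = m + G @ (z_next - F @ m)
        cov = P - G @ F @ P
        cov = 0.5 * (cov + cov.T)
        z = jr.multivariate_normal(key_t, mean, cov)
        return z, z

    # reverse=True walks t = T-1, ..., 1 but returns outputs in time order
    _, zs = lax.scan(back, z_last, (keys[:-1], ms[:-1], Ps[:-1]), reverse=True)
    return jnp.concatenate([zs, z_last[None]], axis=0)

# Usage on a made-up 2-D constant-velocity model with stand-in observations
F = jnp.array([[1.0, 1.0], [0.0, 1.0]])
Q = 0.1 * jnp.eye(2)
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[0.5]])
m0, P0 = jnp.zeros(2), jnp.eye(2)
ys = jnp.sin(jnp.linspace(0.0, 3.0, 50))[:, None]
sample = ffbs_sample(jr.PRNGKey(1), m0, P0, F, Q, H, R, ys)
print(sample.shape)  # (50, 2)
```

As in the pendulum test above, many draws can be generated with `jax.vmap` over a batch of keys.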
This PR also adds the following tests:
In the process of writing these tests, I found that `lgssm_filter` and `extended_kalman_filter` returned asymmetric covariance matrices for the default examples in `dynamax.nonlinear_gaussian_ssm.inference_test` (see issue https://github.com/probml/dynamax/issues/317). As a result, the outputs of `lgssm_posterior_sample` and `extended_kalman_posterior_sample` were all NaN and could not be used for testing. This asymmetry issue is addressed in two separate PRs (https://github.com/probml/dynamax/pull/318 for `lgssm_filter` and https://github.com/probml/dynamax/pull/319 for `extended_kalman_filter`). Both of those PRs need to be merged before this one, or the tests won't work. So the overall merge order would be: