probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Symmetrize LGSSM and EKF filtered covariance #319

Closed calebweinreb closed 1 year ago

calebweinreb commented 1 year ago

This PR addresses numerical instability in EKF and LGSSM inference (issue https://github.com/probml/dynamax/issues/317) by ensuring that the covariance matrices returned by extended_kalman_filter and lgssm_filter are symmetric. In both cases, this is done by explicitly symmetrizing the covariance returned by _condition_on.
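
For readers unfamiliar with the trick, here is a minimal sketch of the idea, not the actual dynamax code. It uses NumPy so that it runs standalone (the same expressions should carry over to jax.numpy), and `symmetrize` and `condition_on_sketch` are hypothetical names chosen for illustration. The update is a textbook Kalman conditioning step under the assumption of a linear observation model y = H x + noise with covariance R.

```python
import numpy as np

def symmetrize(cov):
    """Return the symmetric part of a square matrix: (A + A^T) / 2."""
    return 0.5 * (cov + np.swapaxes(cov, -1, -2))

def condition_on_sketch(m, P, H, R, y):
    """Hypothetical Kalman conditioning step (illustration only).

    Conditions a Gaussian prior N(m, P) on an observation y = H x + noise,
    with noise covariance R, then symmetrizes the posterior covariance.
    """
    S = H @ P @ H.T + R                   # innovation covariance
    K = np.linalg.solve(S, H @ P).T       # Kalman gain, P H^T S^{-1}
    m_cond = m + K @ (y - H @ m)          # posterior mean
    P_cond = P - K @ S @ K.T              # posterior covariance
    # Floating-point round-off can leave P_cond slightly asymmetric;
    # averaging with its transpose removes that drift.
    return m_cond, symmetrize(P_cond)

# A covariance that has drifted slightly from symmetry.
drifted = np.array([[2.0, 0.3 + 1e-7],
                    [0.3, 1.0]])
fixed = symmetrize(drifted)

# One conditioning step on a toy 2-D state with a 1-D observation.
m_cond, P_cond = condition_on_sketch(
    m=np.zeros(2),
    P=np.eye(2),
    H=np.array([[1.0, 0.0]]),
    R=np.array([[1.0]]),
    y=np.array([2.0]),
)
```

Averaging a matrix with its transpose leaves the diagonal unchanged and replaces each off-diagonal pair with its mean, so the correction is on the order of the round-off error it removes.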

The following EKF and LGSSM inference tests pass:

from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_filter_linear,
    test_extended_kalman_filter_nonlinear,
    test_extended_kalman_smoother_linear,
    extended_kalman_smoother_nonlinear)

test_extended_kalman_filter_linear()
test_extended_kalman_filter_nonlinear()
test_extended_kalman_smoother_linear()
extended_kalman_smoother_nonlinear()

from dynamax.linear_gaussian_ssm.inference_test import TestFilteringAndSmoothing

TestFilteringAndSmoothing.test_kalman_tfp(TestFilteringAndSmoothing)
TestFilteringAndSmoothing.test_kalman_vs_joint(TestFilteringAndSmoothing)
TestFilteringAndSmoothing.test_posterior_sampler(TestFilteringAndSmoothing)