Closed (slinderman closed this 6 days ago)
The comparison to `sarkka_lib` started failing for some reason.
    =================================== FAILURES ===================================
    ______________________________ test_ukf_nonlinear ______________________________

    key = 0, num_timesteps = 15

        def test_ukf_nonlinear(key=0, num_timesteps=15):
            nlgssm_args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps)
            hyperparams = UKFHyperParams()

            # Run UKF from sarkka-jax library
            means_ukf, covs_ukf = ukf(*nlgssm_args, *hyperparams, emissions)
            # Run UKS from sarkka-jax library
            means_uks, covs_uks = uks(*nlgssm_args, *hyperparams, emissions)
            # Run UKS from dynamax
            uks_post = unscented_kalman_smoother(nlgssm_args, emissions, hyperparams)

            # Compare filter results
            assert allclose(means_ukf, uks_post.filtered_means)
            assert allclose(covs_ukf, uks_post.filtered_covariances)
    >       assert allclose(means_uks, uks_post.smoothed_means)
    E       assert Array(False, dtype=bool)
    E        +  where Array(False, dtype=bool) = allclose(Array([[-9.5265403e-02, 9.9717736e-01, -2.9405496e-01, -7.4380302e-01],\n [-1.4460543e-01, 3.7296972e-01, 5.36...864e-01, -4.4984221e-01],\n [-5.4547715e-01, -5.2458370e-01, 2.7398711e-01, 6.1779696e-01]], dtype=float32), Array([[-1.4460546e-01, 3.7296972e-01, 5.3632960e-02, -1.3270415e-02],\n [-4.1872060e-01, -5.0879848e-01, -3.60...661e-01, 6.1779743e-01],\n [-9.5266208e-02, 9.9717760e-01, -2.9405388e-01, -7.4380308e-01]], dtype=float32))
    E        +    where Array([[-1.4460546e-01, 3.7296972e-01, 5.3632960e-02, -1.3270415e-02],\n [-4.1872060e-01, -5.0879848e-01, -3.60...661e-01, 6.1779743e-01],\n [-9.5266208e-02, 9.9717760e-01, -2.9405388e-01, -7.4380308e-01]], dtype=float32) = PosteriorGSSMSmoothed(marginal_loglik=Array(-53.798016, dtype=float32), filtered_means=Array([[-0.11624812, 0.4089546...73e-03, 1.28583722e-02, -8.81038457e-02,\n 1.05871034e+00]]], dtype=float32), smoothed_cross_covariances=None).smoothed_means

    dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py:32: AssertionError
    ----------------------------- Captured stdout call -----------------------------
    9.834766e-07 5.9604645e-07 1.5878675
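In the truncated output above, the first visible row of `means_uks` closely matches the last visible row of `uks_post.smoothed_means`, and the second matches the first. That suggests the two arrays may hold the same values offset by one time step, though this is only a reading of the elided log and not a confirmed diagnosis. Below is a minimal, dependency-free sketch of how one might check for such an offset. The helper names `max_abs_diff` and `best_row_shift` and the toy data are hypothetical and not part of dynamax or sarkka-jax.

```python
def max_abs_diff(a, b):
    """Largest absolute elementwise difference between two equal-shaped 2-D lists."""
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


def best_row_shift(reference, candidate):
    """Find the cyclic row shift of `candidate` that best matches `reference`.

    A shift of k compares candidate[(i + k) % T] against reference[i].
    Returns (shift, error), where error is the max absolute difference at that shift.
    """
    num_rows = len(reference)
    results = []
    for shift in range(num_rows):
        shifted = [candidate[(i + shift) % num_rows] for i in range(num_rows)]
        results.append((max_abs_diff(reference, shifted), shift))
    error, shift = min(results)  # smallest error wins; ties go to the smaller shift
    return shift, error


# Toy data (hypothetical): `candidate` is `reference` with its rows rolled by one.
reference = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
candidate = reference[1:] + reference[:1]

direct_error = max_abs_diff(reference, candidate)  # large when rows are misaligned
shift, error = best_row_shift(reference, candidate)
print(direct_error, shift, error)
```

A large direct error combined with a near-zero error at some nonzero shift would point to an indexing mismatch in time rather than a numerical difference between the two smoothers.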
Sorry, it was actually a failure in my logic. I fixed it here: https://github.com/probml/dynamax/pull/368