probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

UKS Inference tests are failing #367

Closed: slinderman closed this issue 6 days ago

slinderman commented 6 days ago

The test comparing dynamax's unscented Kalman smoother against sarkka_lib started failing, and the cause is not obvious.

=================================== FAILURES ===================================
______________________________ test_ukf_nonlinear ______________________________

key = 0, num_timesteps = 15

    def test_ukf_nonlinear(key=0, num_timesteps=15):
        nlgssm_args, _, emissions = random_nlgssm_args(key=key, num_timesteps=num_timesteps)
        hyperparams = UKFHyperParams()

        # Run UKF from sarkka-jax library
        means_ukf, covs_ukf = ukf(*nlgssm_args, *hyperparams, emissions)
        # Run UKS from sarkka-jax library
        means_uks, covs_uks = uks(*nlgssm_args, *hyperparams, emissions)
        # Run UKS from dynamax
        uks_post = unscented_kalman_smoother(nlgssm_args, emissions, hyperparams)

        # Compare filter results
        assert allclose(means_ukf, uks_post.filtered_means)
        assert allclose(covs_ukf, uks_post.filtered_covariances)
>       assert allclose(means_uks, uks_post.smoothed_means)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = allclose(Array([[-9.5265403e-02,  9.9717736e-01, -2.9405496e-01, -7.4380302e-01],\n       [-1.4460543e-01,  3.7296972e-01,  5.36...864e-01, -4.4984221e-01],\n       [-5.4547715e-01, -5.2458370e-01,  2.7398711e-01,  6.1779696e-01]],      dtype=float32), Array([[-1.4460546e-01,  3.7296972e-01,  5.3632960e-02, -1.3270415e-02],\n       [-4.1872060e-01, -5.0879848e-01, -3.60...661e-01,  6.1779743e-01],\n       [-9.5266208e-02,  9.9717760e-01, -2.9405388e-01, -7.4380308e-01]],      dtype=float32))
E        +    where Array([[-1.4460546e-01,  3.7296972e-01,  5.3632960e-02, -1.3270415e-02],\n       [-4.1872060e-01, -5.0879848e-01, -3.60...661e-01,  6.1779743e-01],\n       [-9.5266208e-02,  9.9717760e-01, -2.9405388e-01, -7.4380308e-01]],      dtype=float32) = PosteriorGSSMSmoothed(marginal_loglik=Array(-53.798016, dtype=float32), filtered_means=Array([[-0.11624812,  0.4089546...73e-03,  1.28583722e-02, -8.81038457e-02,\n          1.05871034e+00]]], dtype=float32), smoothed_cross_covariances=None).smoothed_means

dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py:32: AssertionError
----------------------------- Captured stdout call -----------------------------
9.834766e-07
5.9604645e-07
1.5878675
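In the truncated arrays above, the filtered results agree, but the first row of sarkka's means_uks is close to the last row of dynamax's smoothed_means, and its second row is close to the first row of smoothed_means. That pattern hints at an off-by-one in time indexing rather than a numerical discrepancy, although the log alone does not confirm it. Below is a minimal diagnostic sketch, assuming JAX arrays with time on axis 0; max_abs_diff and find_time_shift are hypothetical helpers written for illustration, not part of dynamax or sarkka-jax.

```python
import jax.numpy as jnp


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two arrays."""
    return float(jnp.max(jnp.abs(jnp.asarray(a) - jnp.asarray(b))))


def find_time_shift(reference, candidate, max_shift=2, atol=1e-5):
    """Hypothetical helper: search for an integer shift s in [-max_shift, max_shift]
    such that jnp.roll(reference, s, axis=0) matches candidate.

    jnp.roll wraps around, so the rolled array satisfies
    rolled[t] == reference[(t - s) % T] for T timesteps.
    Returns the first matching s, or None if no shift in range matches.
    """
    reference = jnp.asarray(reference)
    candidate = jnp.asarray(candidate)
    for s in range(-max_shift, max_shift + 1):
        rolled = jnp.roll(reference, s, axis=0)
        # jnp.allclose returns a 0-d boolean Array; it is concrete here, so `if` works.
        if jnp.allclose(rolled, candidate, atol=atol):
            return s
    return None


# Usage on synthetic data (not the arrays from the failing test):
ref = jnp.arange(12.0).reshape(4, 3)         # 4 timesteps, 3-dimensional state
shifted = jnp.roll(ref, -1, axis=0)          # shifted[t] == ref[t + 1], wrapping at the end
print(find_time_shift(ref, shifted))         # -1
print(max_abs_diff(ref, shifted))            # large, even though only the indexing differs
```

Applied to the test, one would call find_time_shift(means_uks, uks_post.smoothed_means). A non-zero shift would point to an indexing offset in one of the two smoothers, while None together with a large max_abs_diff would point to a genuine numerical difference.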
edeno commented 6 days ago

Sorry, the failure was actually caused by a bug in my logic. I fixed it here: https://github.com/probml/dynamax/pull/368