probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

implement explicit MVN #300

Closed: kostastsa closed this 1 year ago

kostastsa commented 1 year ago

References #229. Implemented the large (joint) MVN and added a test comparing it to the Kalman smoother. Feedback welcome!
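For context on why this comparison makes a good test: in a linear Gaussian SSM, the states and emissions over all T steps are jointly Gaussian, so conditioning that single large MVN on the observations must reproduce the Kalman (RTS) smoother's per-step posterior means and covariances. The sketch below illustrates that check in NumPy; it is not the code from this PR. It assumes a model with no inputs or biases (z_1 ~ N(m0, S0), z_t = F z_{t-1} + N(0, Q), y_t = H z_t + N(0, R)), and the function names `joint_state_mvn`, `joint_posterior` and `rts_smoother` are hypothetical.

```python
import numpy as np

# Hypothetical sketch: explicit joint MVN vs. RTS smoother for a linear
# Gaussian SSM without inputs or biases.


def joint_state_mvn(m0, S0, F, Q, T):
    """Mean and covariance of the stacked states z_{1:T} (length T*D)."""
    D = m0.shape[0]
    # Write the dynamics as (I - L) z = b + e, with F on the block sub-diagonal
    # of L, b = [m0, 0, ..., 0] and e ~ N(0, blockdiag(S0, Q, ..., Q)).
    L = np.kron(np.eye(T, k=-1), F)
    A = np.linalg.inv(np.eye(T * D) - L)
    b = np.zeros(T * D)
    b[:D] = m0
    noise_cov = np.kron(np.eye(T), Q)
    noise_cov[:D, :D] = S0
    return A @ b, A @ noise_cov @ A.T


def joint_posterior(m0, S0, F, Q, H, R, ys):
    """Condition the joint MVN over (z_{1:T}, y_{1:T}) on the observations ys."""
    T = ys.shape[0]
    mu_z, Sigma_z = joint_state_mvn(m0, S0, F, Q, T)
    H_big = np.kron(np.eye(T), H)          # blockdiag(H, ..., H)
    R_big = np.kron(np.eye(T), R)          # blockdiag(R, ..., R)
    cov_zy = Sigma_z @ H_big.T
    cov_yy = H_big @ Sigma_z @ H_big.T + R_big
    gain = np.linalg.solve(cov_yy, cov_zy.T).T   # cov_zy @ inv(cov_yy)
    mean = mu_z + gain @ (ys.reshape(-1) - H_big @ mu_z)
    cov = Sigma_z - gain @ cov_zy.T
    return mean.reshape(T, -1), cov


def rts_smoother(m0, S0, F, Q, H, R, ys):
    """Kalman filter followed by a Rauch-Tung-Striebel backward pass."""
    T = ys.shape[0]
    m_pred, P_pred = m0, S0
    filt_m, filt_P, pred_m, pred_P = [], [], [], []
    for t in range(T):
        pred_m.append(m_pred)
        pred_P.append(P_pred)
        S = H @ P_pred @ H.T + R
        K = np.linalg.solve(S, H @ P_pred).T     # P_pred H^T S^{-1}
        m_f = m_pred + K @ (ys[t] - H @ m_pred)
        P_f = P_pred - K @ S @ K.T
        filt_m.append(m_f)
        filt_P.append(P_f)
        m_pred, P_pred = F @ m_f, F @ P_f @ F.T + Q
    sm_m, sm_P = [filt_m[-1]], [filt_P[-1]]
    for t in range(T - 2, -1, -1):
        G = np.linalg.solve(pred_P[t + 1], F @ filt_P[t]).T  # P_f F^T P_pred^{-1}
        sm_m.insert(0, filt_m[t] + G @ (sm_m[0] - pred_m[t + 1]))
        sm_P.insert(0, filt_P[t] + G @ (sm_P[0] - pred_P[t + 1]) @ G.T)
    return np.stack(sm_m), np.stack(sm_P)


# Usage: a small 2-D state, 1-D emission example.
rng = np.random.default_rng(0)
D, E, T = 2, 1, 5
F = np.array([[1.0, 0.1], [0.0, 1.0]])
Q = 0.1 * np.eye(D)
H = np.array([[1.0, 0.0]])
R = np.array([[0.5]])
m0, S0 = np.zeros(D), np.eye(D)
ys = rng.normal(size=(T, E))

joint_mean, joint_cov = joint_posterior(m0, S0, F, Q, H, R, ys)
ks_mean, ks_cov = rts_smoother(m0, S0, F, Q, H, R, ys)
# Diagonal D x D blocks of the joint posterior covariance are the smoothed marginals.
joint_marg_cov = np.stack([joint_cov[t * D:(t + 1) * D, t * D:(t + 1) * D] for t in range(T)])
print(np.allclose(joint_mean, ks_mean), np.allclose(joint_marg_cov, ks_cov))
```

The explicit approach works with dense (T·D)×(T·D) and (T·E)×(T·E) matrices, so its cost grows cubically in T. That makes it a convenient reference for short sequences in tests rather than a practical inference routine.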

slinderman commented 1 year ago

Thanks for contributing this! I'll do a code review asap, but at first glance this looks like exactly what we need!

gileshd commented 1 year ago

Thanks for getting involved!

It might be useful to separate this test from the tfp-based tests (e.g. into test_kalman_tfp() and test_kalman_vs_joint()) so that we still get the results of the joint test even if the tfp test fails.

If you're up for it, the ideal thing might be to refactor the tests to use the class-based approach as in, e.g., info_inference_test.py and parallel_inference_test.py (described here). No worries at all if that's not possible; I'm very happy to do it after this gets merged.
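As a hedged illustration of the class-based layout suggested above (not a copy of info_inference_test.py, whose contents are not shown here), the sketch below builds a shared model once in pytest's `setup_class` hook and puts each comparison in its own test method, so a failure in one check does not hide the result of the other. To stay short it assumes a scalar local-level model (z_t = z_{t-1} + N(0, q), y_t = z_t + N(0, r)) and uses hypothetical helpers `scalar_kalman_filter` and `scalar_joint_posterior`. It compares only the last time step, where the filtered and smoothed posteriors coincide.

```python
import numpy as np

# Hypothetical sketch of a class-based pytest layout; names are illustrative.


def scalar_kalman_filter(m0, s0, q, r, ys):
    """Filtered mean and variance of the final state of a local-level model."""
    m, p = m0, s0                        # prior for z_1
    for t, y in enumerate(ys):
        if t > 0:
            p = p + q                    # predict: z_t = z_{t-1} + N(0, q)
        k = p / (p + r)                  # Kalman gain
        m, p = m + k * (y - m), (1 - k) * p
    return m, p


def scalar_joint_posterior(m0, s0, q, r, ys):
    """Condition the explicit joint MVN over z_{1:T} on y_{1:T}."""
    T = len(ys)
    idx = np.arange(T)
    Sigma = s0 + q * np.minimum.outer(idx, idx)   # Cov(z_i, z_j)
    gain = np.linalg.solve(Sigma + r * np.eye(T), Sigma).T  # Sigma (Sigma + rI)^{-1}
    mean = m0 + gain @ (ys - m0)
    cov = Sigma - gain @ Sigma
    return mean, cov


class TestKalmanVsJoint:
    """The model is built once per class; each check is a separate test."""

    @classmethod
    def setup_class(cls):
        rng = np.random.default_rng(0)
        cls.params = dict(m0=0.0, s0=1.0, q=0.2, r=0.5)
        cls.ys = rng.normal(size=6)
        cls.kf_mean, cls.kf_var = scalar_kalman_filter(ys=cls.ys, **cls.params)
        cls.joint_mean, cls.joint_cov = scalar_joint_posterior(ys=cls.ys, **cls.params)

    def test_last_mean(self):
        # At the final step the filtered and smoothed posteriors coincide.
        assert np.allclose(self.kf_mean, self.joint_mean[-1])

    def test_last_variance(self):
        assert np.allclose(self.kf_var, self.joint_cov[-1, -1])
```

Under pytest, `test_last_mean` and `test_last_variance` are collected and reported independently, while the shared setup runs only once for the class.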

murphyk commented 1 year ago

+1 to splitting the test into two: one for TFP, one for your code.

kostastsa commented 1 year ago

Sounds good, I'm on it!

kostastsa commented 1 year ago

I have done the refactor, following the pattern used in info_inference_test.py. Let me know how it looks!

murphyk commented 1 year ago

LGTM.

murphyk commented 1 year ago

I will merge despite the failing test, which is unrelated to this PR (it has something to do with tqdm).