probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Parallel inference of LGSSM in the EM algorithm (+ some bug fixes) #336

Open kstoneriv3 opened 10 months ago

kstoneriv3 commented 10 months ago

Hi, I wanted to use parallel filtering and smoothing of LGSSMs in the EM algorithm, so I updated the parallel inference functions to reach feature parity with the serial filtering and smoothing functions.
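For readers unfamiliar with why filtering can be parallelized at all, here is a minimal sketch of the underlying idea, assuming a scalar affine recurrence `x_t = a_t * x_{t-1} + b_t` as a stand-in for the filtering updates. It is not the code in this PR: the names `combine`, `sequential_scan`, and `doubling_scan` are hypothetical, and a real LGSSM filter uses richer elements (means, covariances, gains) with an analogous associative operator. The point is only that once the per-step update is written as an associative combination, all prefixes can be computed in a logarithmic number of rounds.

```python
def combine(first, second):
    """Compose two affine maps x -> a*x + b, applying `first` and then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(elems):
    """Reference implementation: fold the elements left to right, one step at a time."""
    out, acc = [], None
    for e in elems:
        acc = e if acc is None else combine(acc, e)
        out.append(acc)
    return out


def doubling_scan(elems):
    """Recursive-doubling prefix scan.

    Runs in ceil(log2(T)) rounds. Within a round every position is updated
    independently of the others, so a round could be executed in parallel.
    """
    out = list(elems)
    shift = 1
    while shift < len(out):
        # Position i absorbs the block that ends `shift` positions earlier.
        out = [
            out[i] if i < shift else combine(out[i - shift], out[i])
            for i in range(len(out))
        ]
        shift *= 2
    return out


# Illustrative (a_t, b_t) pairs; the values are arbitrary.
elems = [(0.5, 1.0), (2.0, -1.0), (0.25, 3.0), (1.5, 0.5), (1.0, 2.0)]
prefixes = doubling_scan(elems)

# Each prefix (A_t, B_t) maps the initial state x_0 directly to x_t.
x0 = 3.0
states = [A * x0 + B for A, B in prefixes]
```

Both scans produce the same prefixes because composing affine maps is associative; the doubling version simply regroups the compositions so that they do not have to be evaluated strictly in time order.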

During the implementation I also found a couple of bugs, so this PR includes fixes for them: the joint sampling logic in inference.py, and a missing emission bias term in the log likelihood in parallel_inference.py.
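To illustrate where an emission bias enters the log likelihood, here is a minimal scalar sketch, assuming the model `z_t = F z_{t-1} + b + N(0, Q)` and `y_t = H z_t + d + N(0, R)` with `d` as the emission bias. The function names and parameter names are hypothetical and this is not the code from parallel_inference.py; it only shows that the predictive mean of `y_t` must include `d`, otherwise the likelihood is evaluated against the wrong mean.

```python
import math


def normal_logpdf(y, mean, var):
    """Log density of a scalar Gaussian N(mean, var) evaluated at y."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mean) ** 2 / var)


def kalman_loglik(ys, m0, P0, F, b, Q, H, d, R):
    """Marginal log likelihood of a scalar LGSSM via a serial Kalman filter.

    m0, P0: mean and variance of the first latent state (prior to seeing y_1).
    F, b, Q: dynamics weight, dynamics bias, dynamics noise variance.
    H, d, R: emission weight, emission bias, emission noise variance.
    """
    m, P, ll = m0, P0, 0.0
    for y in ys:
        # One-step predictive distribution of y_t; note the emission bias d.
        y_mean = H * m + d
        S = H * P * H + R
        ll += normal_logpdf(y, y_mean, S)

        # Condition on y_t (update step).
        K = P * H / S
        m = m + K * (y - y_mean)
        P = (1.0 - K * H) * P

        # Propagate to the next time step (predict step).
        m = F * m + b
        P = F * P * F + Q
    return ll
```

A quick sanity check for a bias term of this kind: shifting every observation and the emission bias by the same constant should leave the log likelihood unchanged, whereas dropping `d` from `y_mean` changes it whenever `d` is nonzero.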

~I thought this branch was almost ready for a PR, but there is a large conflict with the recent diagonal covariance PR. I will mark the PR as ready once the conflict is resolved.~ Now ready for review!
