probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Implementation of Rao-Blackwellised particle filter (RBPF) and RBPF with optimal resampling #323

Closed: kostastsa closed this 8 months ago

kostastsa commented 1 year ago

kostastsa commented 8 months ago

Hi @murphyk @slinderman, could one of you take a look at this PR implementing the RBPF? Please let me know if anything needs to be changed. I understand that implementing an SLDS model class goes beyond what was raised in issue https://github.com/probml/dynamax/issues/271 and might not be necessary; in that case I'd be happy to trim it down to a more minimal implementation.
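To make the RBPF idea concrete, here is a minimal JAX sketch for an SLDS with a discrete regime z_t (Markov transition matrix `Pi`) and a linear-Gaussian continuous state x_t. It assumes a bootstrap proposal (z_t drawn from the prior transition) and plain multinomial resampling at every step. All names (`kalman_step`, `rbpf_step`, `Pi`, `As`, `Qs`, `Cs`, `Rs`) are hypothetical; this is not the code from this PR and not a dynamax API. Each particle carries a sampled regime plus an exact Kalman mean and covariance for x_t, which is what makes the filter Rao-Blackwellised.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

# Sketch only: hypothetical names, bootstrap proposal, not the dynamax API.


def kalman_step(m, P, y, A, Q, C, R):
    """One Kalman predict + update. Returns posterior mean/cov and log p(y | past)."""
    m_pred = A @ m                          # predicted state mean
    P_pred = A @ P @ A.T + Q                # predicted state covariance
    y_pred = C @ m_pred                     # predicted observation mean
    S = C @ P_pred @ C.T + R                # predicted observation covariance
    log_lik = multivariate_normal.logpdf(y, y_pred, S)
    K = jnp.linalg.solve(S, C @ P_pred).T   # gain P_pred C^T S^{-1} (S symmetric)
    m_post = m_pred + K @ (y - y_pred)
    P_post = P_pred - K @ S @ K.T
    return m_post, P_post, log_lik


def rbpf_step(carry, inputs, Pi, As, Qs, Cs, Rs):
    """One RBPF step: sample z_t, Kalman-update x_t exactly, reweight, resample."""
    z_prev, ms, Ps, log_w = carry
    key, y = inputs
    key_z, key_r = jax.random.split(key)
    n = z_prev.shape[0]
    # 1. Bootstrap proposal for the discrete regime: z_t ~ Pi[z_{t-1}].
    z = jax.random.categorical(key_z, jnp.log(Pi[z_prev]), axis=-1)
    # 2. Given z_t, integrate out x_t analytically with a per-particle Kalman step.
    ms, Ps, log_lik = jax.vmap(kalman_step, in_axes=(0, 0, None, 0, 0, 0, 0))(
        ms, Ps, y, As[z], Qs[z], Cs[z], Rs[z])
    # 3. Reweight by p(y_t | z_{1:t}, y_{1:t-1}) and normalise in log space.
    log_w = log_w + log_lik
    log_w = log_w - logsumexp(log_w)
    x_mean = jnp.exp(log_w) @ ms            # filtered mean of x_t before resampling
    # 4. Multinomial resampling; weights are reset to uniform afterwards.
    idx = jax.random.categorical(key_r, log_w, shape=(n,))
    uniform = jnp.full(n, -jnp.log(n), dtype=jnp.float32)
    return (z[idx], ms[idx], Ps[idx], uniform), x_mean


# Toy 2-regime, 1-D model; all numbers are illustrative.
K, n, T = 2, 100, 20
Pi = jnp.array([[0.95, 0.05], [0.05, 0.95]])
As = jnp.array([[[0.9]], [[-0.5]]])
Qs = jnp.array([[[0.1]], [[0.5]]])
Cs = jnp.ones((K, 1, 1))
Rs = 0.2 * jnp.ones((K, 1, 1))
ys = jnp.sin(jnp.arange(T) / 3.0)[:, None]  # stand-in observations, shape (T, 1)

init = (jnp.zeros(n, dtype=jnp.int32), jnp.zeros((n, 1)),
        jnp.tile(jnp.eye(1), (n, 1, 1)), jnp.full(n, -jnp.log(n), dtype=jnp.float32))
keys = jax.random.split(jax.random.PRNGKey(0), T)
step = lambda c, inp: rbpf_step(c, inp, Pi, As, Qs, Cs, Rs)
final, x_means = jax.lax.scan(step, init, (keys, ys))
print(x_means.shape)  # (20, 1)
```

Because x_t is marginalised exactly, the particles only have to cover the discrete regime sequence, which typically reduces variance compared with sampling x_t as well. Variants that draw z_t from a proposal informed by y_t would change step 1 and the weight update; they are not shown here.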

murphyk commented 8 months ago

For the resampling step, maybe it would be worth making a JAX version of https://github.com/nchopin/particles/blob/master/particles/resampling.py#L541?
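A JAX version of multinomial resampling could look like the sketch below. It follows the sorted-uniforms-plus-inverse-CDF approach used in the `particles` package: draw n sorted uniforms from normalised cumulative sums of Exp(1) variables, then invert the CDF of the weights. Instead of a sequential loop, it inverts the CDF with `jnp.searchsorted`, which vectorises well in JAX at O(n log n) cost. The function names are hypothetical, and this is not a line-by-line port of the linked code.

```python
import jax
import jax.numpy as jnp


def uniform_spacings(key, n):
    """n sorted Uniform(0, 1) draws via normalised cumulative sums of Exp(1) variables."""
    z = jnp.cumsum(jax.random.exponential(key, (n + 1,)))
    return z[:-1] / z[-1]


def inverse_cdf(su, W):
    """Map sorted uniforms su to ancestor indices under normalised weights W."""
    cdf = jnp.cumsum(W)
    # side="right" counts cdf entries <= u, i.e. the first index i with cdf[i] > u.
    idx = jnp.searchsorted(cdf, su, side="right")
    # Guard against round-off leaving cdf[-1] slightly below 1.
    return jnp.minimum(idx, W.shape[0] - 1)


def multinomial_resample(key, W, m):
    """Draw m ancestor indices i.i.d. from Categorical(W); output is sorted."""
    return inverse_cdf(uniform_spacings(key, m), W)


key = jax.random.PRNGKey(0)
W = jnp.array([0.1, 0.0, 0.6, 0.3])
idx = multinomial_resample(key, W, 10_000)
freqs = jnp.bincount(idx, length=4) / idx.shape[0]
print(freqs)  # approximately [0.1, 0.0, 0.6, 0.3]
```

Because the uniforms are generated already sorted, the returned indices are sorted too, which mirrors the behaviour of the inverse-CDF approach and is harmless for resampling.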

murphyk commented 8 months ago

LGTM. We can modify resampling code later (if necessary).

slinderman commented 8 months ago

Thanks, @kostastsa. One of the big items on my wishlist is to have a solid implementation of SLDS model variants and inference algorithms in dynamax. This is a great start.

kostastsa commented 8 months ago

@murphyk Cool, I will look at the multinomial resampling from Chopin and write a JAX version. @slinderman That's great to hear! I would be very interested in being part of this effort. I can start on it soon, since I will also need it for my own research.