probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Implement Rao-Blackwellised particle filtering using dynamax and blackjax #271

Open murphyk opened 1 year ago

murphyk commented 1 year ago

Combine the Kalman filter from dynamax with sequential Monte Carlo from blackjax to implement a Rao-Blackwellised particle filter (RBPF) for switching linear Gaussian SSMs. For details, see sec. 13.4.1 of https://probml.github.io/pml-book/book2.html

murphyk commented 1 year ago

See also https://github.com/probml/JSL/blob/main/jsl/lds/mixture_kalman_filter.py

kostastsa commented 1 year ago

Is this still of interest? If so I can have a go at it. Also I was thinking that maybe this could be part of a larger submodule of dynamax with other SLDS inference and learning functions?

murphyk commented 1 year ago

Yes, this is still of interest. You may want to look into the SMC support in blackjax and see whether we can pass a dynamax model to it for a vanilla PF, then use the marginalization trick from dynamax.kalman_filter to do RBPF.

kostastsa commented 1 year ago

I have looked into the SMC support in blackjax and, as far as I can tell, blackjax.smc.base.py is an implementation of waste-free SMC for approximating a sequence of distributions. Correct me if I am missing something, but in the RBPF we are sampling a discrete Markov chain, so there is no need for the MCMC steps of the waste-free sampler. Wouldn't it be sufficient to use our own custom proposals? Also, a link to dynamax.kalman_filter would be much appreciated; I cannot find the file in the current version of the repo.

murphyk commented 1 year ago

It's true that for RBPF we can just use a simple discrete proposal, so blackjax may be irrelevant.

Here is the KF code: https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/inference.py#L370
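To make the "simple discrete proposal" idea concrete, here is a minimal sketch of one RBPF step. It is an illustration under stated assumptions, not the dynamax or blackjax API: the state is scalar, the names `kalman_step` and `rbpf_step` and the `(a, q, h, r)` parameter layout are hypothetical, the discrete mode is proposed from the transition prior, and resampling is left out. Each particle carries a mode `z` plus the Kalman mean and variance of the continuous state, and its weight is multiplied by the Kalman predictive likelihood of the new observation.

```python
import math
import random


def kalman_step(mean, var, y, a, q, h, r):
    """One scalar Kalman predict/update.

    Model (assumed): x_t = a * x_{t-1} + N(0, q),  y_t = h * x_t + N(0, r).
    Returns the filtered mean, filtered variance and log p(y_t | past).
    """
    pred_mean = a * mean
    pred_var = a * a * var + q
    innov_var = h * h * pred_var + r
    resid = y - h * pred_mean
    log_lik = -0.5 * (math.log(2.0 * math.pi * innov_var) + resid * resid / innov_var)
    gain = pred_var * h / innov_var
    return pred_mean + gain * resid, (1.0 - gain * h) * pred_var, log_lik


def rbpf_step(particles, log_weights, y, trans, params, rng):
    """Advance every particle by one time step.

    particles: list of (z, mean, var); trans[z] is the row p(z_t | z_{t-1} = z);
    params[z] is the hypothetical tuple (a, q, h, r) for mode z.
    """
    new_particles, new_log_w = [], []
    for (z, mean, var), lw in zip(particles, log_weights):
        # Proposal: sample the discrete mode from the transition prior.
        z_new = rng.choices(range(len(trans)), weights=trans[z])[0]
        a, q, h, r = params[z_new]
        # The continuous state is marginalised exactly by the Kalman filter.
        mean, var, log_lik = kalman_step(mean, var, y, a, q, h, r)
        new_particles.append((z_new, mean, var))
        new_log_w.append(lw + log_lik)
    # Normalise the weights in log space for numerical stability.
    top = max(new_log_w)
    log_norm = top + math.log(sum(math.exp(l - top) for l in new_log_w))
    return new_particles, [l - log_norm for l in new_log_w]
```

Because the proposal is the prior over modes, the incremental weight reduces to the predictive likelihood alone. A vector-valued version would swap `kalman_step` for the matrix Kalman update in the linked inference module.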

kostastsa commented 1 year ago

Sounds good! I'll get started on it then.

kostastsa commented 1 year ago

I have a question regarding implementation of the optimal resampling method:

Is there an efficient way to determine the threshold $c$ of the optimal resampling method from Fearnhead & Clifford?

The best I have come up with is the naive approach: try each candidate value of $L$ (the number of resampled particles) and check whether the resulting $c$ satisfies the equation. Is there another (potentially approximate) algorithm that does this more efficiently? Once this has been addressed I can submit a PR for this feature. Thanks!
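One possible answer, offered as a sketch rather than a vetted implementation: assuming the threshold is defined as the unique $c$ solving $\sum_i \min(c\,w_i, 1) = N$ for normalised weights $w_i$ and target particle count $N$ (my reading of Fearnhead & Clifford), sorting the weights once turns the search into a single pass. If the `n_above` largest weights sit at or above $1/c$, then $c = (N - \text{n\_above}) / (\text{sum of the remaining weights})$, and the first `n_above` for which the next-largest weight falls at or below $1/c$ is the solution. The function name `fc_threshold` and its return convention are hypothetical.

```python
from itertools import accumulate


def fc_threshold(weights, n_target):
    """Solve sum_i min(c * w_i, 1) = n_target for c.

    Returns (c, n_above), where n_above is the number of weights with
    c * w_i >= 1. Cost is O(M log M) for the sort plus one linear scan.
    """
    total = sum(weights)
    w = sorted((x / total for x in weights), reverse=True)
    if not 0 < n_target <= len(w):
        raise ValueError("n_target must lie in 1..len(weights)")
    # suffix[k] = sum of w[k:], accumulated from the smallest weight upwards.
    suffix = list(accumulate(reversed(w)))[::-1]
    for n_above in range(n_target):
        if suffix[n_above] <= 0.0:
            break
        c = (n_target - n_above) / suffix[n_above]
        # Valid once the largest remaining weight is at or below 1 / c.
        if c * w[n_above] <= 1.0:
            return c, n_above
    raise ValueError("fewer than n_target particles have non-zero weight")
```

For example, `fc_threshold([0.5, 0.3, 0.1, 0.1], 3)` gives `c = 5` with two weights above the threshold: $\min(2.5, 1) + \min(1.5, 1) + 0.5 + 0.5 = 3$. Since the validity check is monotone in `n_above`, the scan could also be replaced by a binary search over the sorted weights.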

Sibgatulin commented 4 months ago

Thanks for your effort @kostastsa! I am still wrapping my head around application of RBPF to SLDSs, but so far I've been able to filter the state in my model quite successfully. I am now wondering if I understand it correctly that the current implementation of SLDS does not yet support parameter estimation (e.g. dynamic and emission covariances). And if so, what would be the most reasonable way to extend the model? The PML2 book hints at SMC for parameter inference, but I, frankly, cannot quite see the exact way to combine the tempering techniques from section 13.6 with this RBPF approach.

kostastsa commented 4 months ago

Hi @Sibgatulin, thank you for your interest and for your question. You understand correctly: the SLDS class does not currently support parameter estimation, although it's definitely something we're interested in adding. I think the most popular way of doing parameter estimation in this model is the EM algorithm. See for example @murphyk's books (or this technical report, which is the most detailed description that I have found). The SMC approach would be useful for computing the posterior over the parameters. It is partly a matter of preference, but I would begin with an implementation of EM rather than SMC. Hope this addresses some of your questions! Best, Kostas