probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Add support for parallel hmm sampler #341

Open calebweinreb opened 10 months ago

calebweinreb commented 10 months ago

Currently there are associative-scan implementations for HMM smoothing and filtering, but not for sampling.

AdrienCorenflos commented 10 months ago

HMM sampling can't really be done with a prefix sum. You would need a divide-and-conquer approach, as in Section 3.2 of my paper https://arxiv.org/abs/2303.00301 (note that it's for Gaussians, but the general principle applies).


slinderman commented 10 months ago

@AdrienCorenflos, I think @calebweinreb has actually come up with a clever way to implement the HMM (and LGSSM) sampling with associative_scan! Curious what you think of his implementation in this PR: https://github.com/probml/dynamax/pull/342. You can see the corresponding parallel LGSSM sampler here: https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L356.

AdrienCorenflos commented 10 months ago

The LGSSM is unsurprising (it's just a cumulative sum over pre-generated Gaussians). In fact I think I had it implemented some time ago.

I'm actually surprised by the HMM case though. I had given it some thought and decided it was not possible because I didn't see how selecting indices could be associative. I need to check how that is done here. Is there a small proof for it somewhere?
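For reference, the cumulative-sum view of parallel LGSSM sampling mentioned above can be sketched as follows: once the process noise is drawn up front, the trajectory $z_{t+1} = A z_t + w_t$ is an affine recursion, and affine maps compose associatively. This is a minimal sketch with a made-up toy model, not dynamax's actual API:

```python
import jax.numpy as jnp
from jax import lax, random

# Affine maps z -> A z + b compose associatively:
# applying (A1, b1) then (A2, b2) gives (A2 @ A1, A2 @ b1 + b2).
def combine(e1, e2):
    A1, b1 = e1  # earlier map
    A2, b2 = e2  # later map
    return A2 @ A1, jnp.einsum('...ij,...j->...i', A2, b1) + b2

D, T = 2, 6
A = jnp.array([[0.9, 0.1],
               [0.0, 0.8]])                          # toy dynamics matrix
w = 0.1 * random.normal(random.PRNGKey(1), (T, D))   # pre-sampled Gaussian noise

As = jnp.broadcast_to(A, (T, D, D))
# Prefix compositions in O(log T) depth; the bias of the t-th prefix,
# evaluated at z_0 = 0, is exactly z_{t+1} of the sequential recursion.
_, zs = lax.associative_scan(combine, (As, w))
```

The point is that the "cumulative sum" is really a cumulative composition of affine maps, which `associative_scan` parallelizes.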


calebweinreb commented 10 months ago

Hi Adrien,

The way I think about the algorithm, each associative scan element is a function $E_{st}: [K] \to [K]$ where $E_{st}(i) \mapsto z_s \sim P(z_s \mid x_{1:t-1}, z_t = i)$. The functions can be thought of as draws from a random process. The operator is function composition, and it's associative because function composition is associative.

slinderman commented 10 months ago

I think the other necessary ingredient to the proof is that the function composition $E_{su} = E_{st} \circ E_{tu}$, defined by $E_{su}(i) = E_{st}(E_{tu}(i))$, ensures that $E_{su}(i) \mapsto z_s \sim P(z_s \mid x_{1:u-1}, z_u = i)$. Moreover, $(E_{su}(i), E_{tu}(i))$ are jointly distributed as $P(z_s, z_t \mid x_{1:u-1}, z_u = i)$. These follow from the fact that $z_s$ is conditionally independent of $x_{t:u-1}$ given $z_t$ in an HMM.
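One way to spell out the marginalization behind the first claim (for $s < t < u$), using $z_s \perp x_{t:u-1}, z_u \mid z_t$ in an HMM:

```latex
\begin{align}
P(z_s \mid x_{1:u-1}, z_u = i)
  &= \sum_{z_t} P(z_s \mid x_{1:u-1}, z_t, z_u = i)\, P(z_t \mid x_{1:u-1}, z_u = i) \\
  &= \sum_{z_t} P(z_s \mid x_{1:t-1}, z_t)\, P(z_t \mid x_{1:u-1}, z_u = i).
\end{align}
```

Since $E_{tu}(i)$ is a draw from the second factor, and $E_{st}$ applied to that draw samples from the first, the composition $E_{st}(E_{tu}(i))$ performs this marginalization over $z_t$ by sampling.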

Finally, note that the base cases $E_{s,s+1}$ are straightforward to sample given the filtering distributions. The final message is defined as $E_{T,T+1} \triangleq E_T \mapsto z_T \sim p(z_T \mid x_{1:T})$.
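A sketch of sampling one base-case map from a filtering distribution (a hypothetical helper, not dynamax's API; here `filt_t[j]` stands for $P(z_t = j \mid x_{1:t})$ and `A[j, i]` for $P(z_{t+1} = i \mid z_t = j)$):

```python
import jax.numpy as jnp
from jax import random, vmap

def sample_base_map(key, filt_t, A):
    """Draw E(i) ~ P(z_t | x_{1:t}, z_{t+1} = i) for every conditioning state i."""
    # Unnormalized backward conditional: column i is filt_t * A[:, i]
    probs = filt_t[:, None] * A            # shape (K, K)
    probs = probs / probs.sum(axis=0)      # normalize each column
    keys = random.split(key, probs.shape[1])
    # One categorical draw per column, i.e. per value of z_{t+1}
    return vmap(lambda k, p: random.categorical(k, jnp.log(p)))(keys, probs.T)
```

As a sanity check, with an identity transition matrix the backward conditional is deterministic and the sampled map collapses to $E(i) = i$.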