Closed calebweinreb closed 2 weeks ago
Thanks @calebweinreb! To clarify, I would say that the associative operator takes in two sets of samples, $$z_s \sim p(z_s \mid x_{1:s}, z_{s+1})$$ and $$z_t \sim p(z_t \mid x_{1:t}, z_{t+1})$$ for all values of $z_{s+1} \in [K]$ and $z_{t+1} \in [K]$.
Then, assuming $t > s$, the associative operator returns a sample $$z_s \sim p(z_s \mid x_{1:t}, z_{t+1})$$ for all $z_{t+1} \in [K]$.
The final message is a sample $z_T \sim p(z_T \mid x_{1:T})$, replicated $K$ times so that it is the same shape as the preceding messages.
The output of associative scan thus yields samples of $z_{1:T} \sim p(z_{1:T} \mid x_{1:T})$. The output shape is $(T, K)$, but all columns are identical since they all started with the same final state. Thus, it suffices to take the first column of the output matrix.
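To make the shapes concrete, here is a minimal pure-Python sketch of the scheme described above. It is not the PR's code: the names `compose`, `reverse_scan` and `messages` are hypothetical, the per-step samples are uniform placeholders rather than draws from the actual conditionals, and a serial loop stands in for the parallel associative scan. States are indexed from 0.

```python
import random

def compose(earlier, later):
    """Combine two messages. For each conditioning state k, look up the state
    that `later` sampled, then the state that `earlier` sampled given it."""
    return [earlier[later[k]] for k in range(len(later))]

def reverse_scan(messages):
    """Serial stand-in for a reverse associative scan.
    out[t] is messages[t] composed with every later message."""
    out = [None] * len(messages)
    acc = messages[-1]
    out[-1] = acc
    for t in range(len(messages) - 2, -1, -1):
        acc = compose(messages[t], acc)
        out[t] = acc
    return out

rng = random.Random(0)
T, K = 6, 3

# Placeholder samples (assumption): messages[t][k] stands for a draw of z_t
# given z_{t+1} = k. A real sampler would draw these from the conditionals.
messages = [[rng.randrange(K) for _ in range(K)] for _ in range(T - 1)]

# Final message: a single draw of the last state, replicated K times.
z_last = rng.randrange(K)
messages.append([z_last] * K)

out = reverse_scan(messages)      # T rows, K columns
path = [row[0] for row in out]    # first column is the sampled path

# Reference: ordinary serial backward sampling using the same draws.
ref = [None] * T
ref[-1] = z_last
for t in range(T - 2, -1, -1):
    ref[t] = messages[t][ref[t + 1]]
```

Because the last message is constant in `k`, every row of `out` is constant too, which is why taking the first column loses nothing, and `path` matches the serial reference `ref`.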
This looks really neat @calebweinreb!
One question about the timing results: were they run on a CPU or a GPU? I remember the behaviour being a bit different across backends in the context of LGSSM inference (for instance, in the results from Adrien).
Hi Scott, thanks for clarifying! I think we landed on a good way of articulating the algorithm over Slack. I'll repost here in case others are interested:
> This looks really neat @calebweinreb!
> One question about the timing results: were they run on a CPU or a GPU? I remember the behaviour being a bit different across backends in the context of LGSSM inference (for instance, in the results from Adrien).
I ran the test on a GPU. I assume on a CPU, parallel would always do worse?
This PR implements a parallel version of HMM posterior sampling using associative scan (see https://github.com/probml/dynamax/issues/341). The scan elements $E_{ij}$ are vectors specifying a sample for each possible value of $z_i$. They can be thought of as functions $E : \{1, \ldots, n\} \to \{1, \ldots, n\}$, where the associative operator is function composition. This implementation passes the test written for serial sampling (which is commented out for some reason). It starts to outperform serial sampling once the sequence length exceeds a few thousand (I'm a little mystified as to why it takes so long for the crossover to happen).
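As a small illustration of why function composition works as the scan operator, the sketch below stores each map on $n$ states as an index list and checks that regrouping the compositions does not change the result. This is an assumption-laden toy, not the PR's implementation: `compose` and `tree_reduce` are hypothetical names, the maps are random, and states are indexed from 0 rather than 1.

```python
import random
from functools import reduce

def compose(f, g):
    """Return the map k -> f[g[k]] for maps on {0, ..., n-1} stored as lists."""
    return [f[g[k]] for k in range(len(g))]

def tree_reduce(elems):
    """Combine neighbours pairwise, level by level, keeping their order,
    the way a parallel reduction would group the work."""
    while len(elems) > 1:
        paired = [compose(elems[i], elems[i + 1])
                  for i in range(0, len(elems) - 1, 2)]
        if len(elems) % 2:
            paired.append(elems[-1])  # odd element carries over unchanged
        elems = paired
    return elems[0]

rng = random.Random(1)
n = 4
elems = [[rng.randrange(n) for _ in range(n)] for _ in range(7)]

# Associativity on three maps: (a . b) . c equals a . (b . c).
a, b, c = elems[:3]
assoc_ok = compose(compose(a, b), c) == compose(a, compose(b, c))

# Left-to-right fold and pairwise tree grouping give the same overall map.
serial = reduce(compose, elems)
tree = tree_reduce(elems)
```

Associativity is what allows a scan to combine elements in a balanced tree instead of strictly left to right, and composition of maps is associative even though it is not commutative, so only the grouping may change, never the order.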