probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Parallel hmm posterior sample #342

Closed calebweinreb closed 2 weeks ago

calebweinreb commented 10 months ago

This PR implements a parallel version of HMM posterior sampling using associative scan (see https://github.com/probml/dynamax/issues/341). The scan elements $E_{ij}$ are vectors specifying a sample

$$z_j \sim p(z_j \mid z_i)$$

for each possible value of $z_i$. They can be thought of as functions $E : [1, \ldots, n] \to [1, \ldots, n]$, where the associative operator is function composition. This implementation passes the test written for serial sampling (which is currently commented out for some reason). It starts performing better than serial sampling once the sequence length exceeds a few thousand (I'm a little mystified as to why the crossover takes so long to happen).
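To make the function-composition view concrete, here is a minimal standalone sketch (illustrative only, not the PR's actual code): each scan element is a length-$n$ integer array `e` with `e[i]` a sample of $z_j$ given $z_i = i$, and the associative operator composes two such arrays by indexing.

```python
import jax
import jax.numpy as jnp

def compose(e1, e2):
    # Composition by indexing: out[..., i] = e2[..., e1[..., i]].
    # Viewing each element as a map [1,...,n] -> [1,...,n], this is
    # function composition, which is associative (but not commutative).
    return jnp.take_along_axis(e2, e1, axis=-1)

# Three toy elements over n = 3 states.
elems = jnp.array([[1, 2, 0],
                   [2, 0, 1],
                   [0, 0, 2]])

# Cumulative compositions of the elements under the operator.
print(jax.lax.associative_scan(compose, elems))
```

The timing comparison below benchmarks the new implementation against serial sampling: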

```python
import time

import jax
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np

from dynamax.hidden_markov_model import hmm_posterior_sample, parallel_hmm_posterior_sample
from dynamax.hidden_markov_model.inference_test import random_hmm_args

num_states = 5
num_iters = 5
timesteps = np.logspace(0, 6, 10).astype(int)

serial_times, parallel_times = [], []
for num_timesteps in timesteps:
    print(num_timesteps)
    serial_time, parallel_time = 0.0, 0.0
    # Run num_iters + 1 iterations; the first warms up JIT compilation
    # and is excluded from the averages.
    for itr in range(num_iters + 1):
        args = random_hmm_args(jr.PRNGKey(itr), num_timesteps, num_states)

        # block_until_ready ensures the timings capture the actual compute,
        # not just JAX's asynchronous dispatch.
        t = time.time()
        jax.block_until_ready(hmm_posterior_sample(jr.PRNGKey(itr), *args))
        elapsed = time.time() - t
        print('s', elapsed)
        if itr > 0: serial_time += elapsed

        t = time.time()
        jax.block_until_ready(parallel_hmm_posterior_sample(jr.PRNGKey(itr), *args))
        elapsed = time.time() - t
        print('p', elapsed)
        if itr > 0: parallel_time += elapsed

    serial_times.append(serial_time / num_iters)
    parallel_times.append(parallel_time / num_iters)

plt.plot(timesteps, serial_times, label='serial')
plt.plot(timesteps, parallel_times, label='parallel')
plt.legend(loc='upper left')
plt.xscale('log')
plt.yscale('log')
plt.ylabel('Runtime (s)')
plt.xlabel('Sequence length')
plt.gcf().set_size_inches((3, 2))
```
*(Screenshot: log-log plot of runtime vs. sequence length for the serial and parallel samplers.)*
slinderman commented 10 months ago

Thanks @calebweinreb! To clarify, I would say that the associative operator takes in two sets of samples, $$z_s \sim p(z_s \mid x_{1:s}, z_{s+1})$$ and $$z_t \sim p(z_t \mid x_{1:t}, z_{t+1})$$ for all values of $z_{s+1} \in [K]$ and $z_{t+1} \in [K]$.

Then, assuming $t > s$, the associative operator returns a sample $$z_s \sim p(z_s \mid x_{1:t}, z_{t+1})$$ for all $z_{t+1} \in [K]$.

The final message is a sample $z_T \sim p(z_T \mid x_{1:T})$, replicated $K$ times so that it is the same shape as the preceding messages.

The output of the associative scan thus yields samples of $z_{1:T} \sim p(z_{1:T} \mid x_{1:T})$. The output shape is $(T, K)$, but all columns are identical since they all started from the same final state. Thus, it suffices to take the first column of the output matrix.
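To make the shapes concrete, here is a toy end-to-end sketch of this description (with hypothetical stand-in messages, not dynamax's implementation): build $K$-vectors of backward samples, replicate the final sample $K$ times, run a reverse associative scan with the composition operator, and read off column 0.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

K, T = 3, 5

# Toy messages: elems[t, k] stands in for a draw of z_t given z_{t+1} = k.
# (Random integers here; in the real algorithm these are conditional samples.)
elems = jr.randint(jr.PRNGKey(0), (T, K), 0, K)
# Final message: a single draw of z_T, replicated K times.
elems = elems.at[-1].set(jnp.full(K, elems[-1, 0]))

def combine(e_s, e_t):
    # e_s[k] samples z_s given z_{s+1} = k; e_t[k] samples z_t given
    # z_{t+1} = k. Composing them samples z_s given z_{t+1} = k: e_s[e_t[k]].
    return jnp.take_along_axis(e_s, e_t, axis=-1)

samples = jax.lax.associative_scan(combine, elems, reverse=True)
# All K columns agree because every composition ends at the same replicated
# final message, so column 0 is the joint sample z_{1:T}.
assert jnp.all(samples == samples[:, :1])
print(samples[:, 0])
```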

gileshd commented 10 months ago

This looks really neat @calebweinreb!

One question about the timing results: is that on a CPU or a GPU? I remember the behaviour being a bit different for different backends in the context of LGSSM inference (for instance, the results from Adrien).

calebweinreb commented 10 months ago

Hi Scott, thanks for clarifying! I think we landed on a good way of articulating the algorithm over Slack, so I'll repost it here in case others are interested:

calebweinreb commented 10 months ago

> This looks really neat @calebweinreb!
>
> One question about the timing results: is that on a CPU or a GPU? I remember the behaviour being a bit different for different backends in the context of LGSSM inference (for instance, the results from Adrien).

I ran the test on a GPU. I assume that on a CPU the parallel version would always do worse?