pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Support HMM via marginalization of DiscreteMarkovChain #257

Closed ricardoV94 closed 4 months ago

ricardoV94 commented 8 months ago

The following example defines a 2-state HMM, with a 0.9 transition probability of staying in the same state, and a Normal emission centered around -1 for state 0 and 1 for state 1.

import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
from pymc_experimental import MarginalModel
from pymc_experimental.distributions import DiscreteMarkovChain

with MarginalModel() as m:
    P = [[0.9, 0.1], [0.1, 0.9]]
    init_dist = pm.Categorical.dist(p=[1, 0])
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=10)
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=0.5)

    # Sum the hidden chain out of the model; only `emission` is left to sample
    m.marginalize([chain])

with m:
    idata = pm.sample(100)

plt.plot(az.extract(idata)["emission"].values, color="k", alpha=0.03)
plt.yticks([-1, 1])
plt.ylabel("Emission")
plt.xlabel("Step");

[Figure: sampled `emission` trajectories plotted against Step; y-axis "Emission" with ticks at -1 and 1]
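
For readers new to HMMs, the generative process that the model above describes can be sketched in plain NumPy. `simulate_hmm` is a hypothetical helper, not part of pymc-experimental, and it assumes that `steps=10` yields 11 states including the initial one. It mirrors `P`, `init_dist=[1, 0]` and the `mu=chain * 2 - 1` mapping from the example:

```python
import numpy as np

def simulate_hmm(P, init_probs, mus, sigma, steps, rng):
    """Ancestral sampling: draw a hidden state path, then Normal emissions."""
    P = np.asarray(P)
    n_states = P.shape[0]
    # Assumption: the chain has steps + 1 entries, including the initial state
    states = np.empty(steps + 1, dtype=int)
    states[0] = rng.choice(n_states, p=init_probs)
    for t in range(1, steps + 1):
        # The next state depends only on the current one (first-order chain)
        states[t] = rng.choice(n_states, p=P[states[t - 1]])
    # Each step emits Normal(mus[state], sigma) independently given its state
    emissions = rng.normal(loc=np.asarray(mus)[states], scale=sigma)
    return states, emissions

rng = np.random.default_rng(0)
P = [[0.9, 0.1], [0.1, 0.9]]
# mu = chain * 2 - 1 maps state 0 -> -1 and state 1 -> 1
states, emissions = simulate_hmm(P, init_probs=[1.0, 0.0], mus=[-1.0, 1.0],
                                 sigma=0.5, steps=10, rng=rng)
```

Since `emission` is not observed in the example, the draws from the marginalized model should be distributed like repeated calls to this function; marginalizing only removes the need to sample `chain` explicitly.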

Not implemented

Higher-order lags and batched P matrices are not supported due to complexity (and me not grokking the exact API).

Closes #167

junpenglao commented 7 months ago

Is it using the Viterbi algorithm?

jessegrabowski commented 7 months ago

Currently just the forward algorithm* to compute the logp

*It's not pure forward because we are computing and storing p(data | state) for all data-state pairs outside the scan over state transition probabilities. We should be O(N^2*T) on compute (N states, T steps), but we're not maximally efficient on memory.

If I understand correctly, Viterbi just gives the most probable sequence of hidden states in a maximum likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm oversimplifying.

junpenglao commented 7 months ago

> If I understand correctly, Viterbi just gives the most probable sequence of hidden states in a maximum likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm oversimplifying.

Yes, Viterbi gives the posterior mode, but you are marginalizing the state to compute the likelihood here, right?
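
For reference, Viterbi is the max-product counterpart of the forward recursion: replacing the `logsumexp` over the previous state with a `max` (and keeping back-pointers) gives the single most probable hidden path for fixed parameters, i.e. the joint posterior mode mentioned above. This is a minimal NumPy sketch, not something this PR implements; `viterbi` and its argument layout (rows of `log_P` indexed by the previous state) are assumptions:

```python
import numpy as np
from scipy import stats

def viterbi(log_emission, log_P, log_init):
    """Most probable hidden path for fixed HMM parameters.

    log_emission: (T, N) log p(y_t | s_t); log_P: (N, N) log transition
    matrix with rows = previous state; log_init: (N,) log initial probs.
    """
    T, N = log_emission.shape
    # delta[n]: best joint log-prob of any path ending in state n at step t
    delta = log_init + log_emission[0]
    backptr = np.zeros((T, N), dtype=int)
    for t in range(1, T):
        scores = delta[:, None] + log_P          # (previous, next)
        backptr[t] = np.argmax(scores, axis=0)   # best predecessor of each state
        delta = scores.max(axis=0) + log_emission[t]
    # Follow back-pointers from the best final state
    path = np.empty(T, dtype=int)
    path[-1] = np.argmax(delta)
    for t in range(T - 1, 0, -1):
        path[t - 1] = backptr[t, path[t]]
    return path, delta.max()

# Bernoulli example used later in this thread
data = np.array([0, 0, 1])
log_emission = stats.bernoulli(p=[0.2, 0.6]).logpmf(data[:, None])
path, best_logp = viterbi(log_emission,
                          np.log([[0.5, 0.5], [0.3, 0.7]]),
                          np.log([0.375, 0.625]))
print(path)  # [0 0 1]
```

This answers a different question from the marginal logp: the forward algorithm sums over all paths, while Viterbi keeps only the best one.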

ricardoV94 commented 7 months ago

> but you are marginalizing the state to compute the likelihood here, right?

Yes, but to be precise: to compute the logp of any dependent variables, which may be observed/unobserved or a mix.

ricardoV94 commented 7 months ago

Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out but it doesn't seem to be the case.
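
To see why the emissions cannot be factored out, here is one reading of the "clever" approach (a guess at its exact form): propagate the state marginals with `P` alone and mix in the emission likelihoods at each step independently. It is compared with the forward recursion and a brute-force sum over all hidden paths, on the 2-state Bernoulli example jessegrabowski adds as a test case. All function names are illustrative:

```python
import itertools
import numpy as np
from scipy import stats
from scipy.special import logsumexp

P = np.array([[0.5, 0.5], [0.3, 0.7]])
init = np.array([0.375, 0.625])
data = np.array([0, 0, 1])
# log p(y_t | s_t = n) for every data-state pair, shape (T, N)
log_emission = stats.bernoulli(p=[0.2, 0.6]).logpmf(data[:, None])
T, N = log_emission.shape

def factored_logp():
    # State marginals from P alone, emissions mixed in per step: incorrect,
    # because it treats the y_t as independent given the prior state marginals
    state_probs = init
    total = 0.0
    for t in range(T):
        total += logsumexp(np.log(state_probs) + log_emission[t])
        state_probs = state_probs @ P
    return total

def forward_logp():
    # Emissions enter the recursion at every step, so earlier data
    # reweights the state probabilities used at later steps
    log_alpha = np.log(init) + log_emission[0]
    for t in range(1, T):
        log_alpha = log_emission[t] + logsumexp(log_alpha[:, None] + np.log(P), axis=0)
    return logsumexp(log_alpha)

def brute_force_logp():
    # Sum the joint probability over all N**T hidden paths (tiny T only)
    terms = []
    for path in itertools.product(range(N), repeat=T):
        lp = np.log(init[path[0]]) + log_emission[0, path[0]]
        for t in range(1, T):
            lp += np.log(P[path[t - 1], path[t]]) + log_emission[t, path[t]]
        terms.append(lp)
    return logsumexp(terms)

print(np.exp([forward_logp(), brute_force_logp(), factored_logp()]))
# ≈ [0.1344 0.1344 0.1361]
```

The factored version ignores that observing y_0 = 0 changes the probability of s_1, so it misses the exact value. Precomputing the (T, N) emission matrix is still fine; what matters is that each row enters the recursion at its own step.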

jessegrabowski commented 7 months ago

> Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out but it doesn't seem to be the case.

I added the example from this YouTube video as a test case, so we can get to a solution.

I'm in the process of refactoring the logp function to compute alpha correctly, but it's typically a nested loop. Here's the NumPy code:

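import numpy as np
from scipy import stats
from scipy.special import logsumexp

transition_probs = np.array([[0.5, 0.5],
                             [0.3, 0.7]])
initial_probs = np.array([0.375, 0.625])

T = 3
data = [0, 0, 1]
log_alpha = np.zeros((T, 2))
# Emission distribution for each hidden state
x_dists = [stats.bernoulli(p=0.2), stats.bernoulli(p=0.6)]

def eval_logp(x, dists):
    # log p(x | state) for every state
    return np.array([d.logpmf(x) for d in dists])

log_alpha[0, :] = np.log(initial_probs) + eval_logp(data[0], x_dists)
for t in range(1, T):
    obs = data[t]
    for s in range(transition_probs.shape[0]):
        # Sum over the previous state, plus the emission of the current state s
        step_log_prob = x_dists[s].logpmf(obs) + np.log(transition_probs[:, s]) + log_alpha[t-1, :]
        log_alpha[t, s] = logsumexp(step_log_prob)

# Marginal log-likelihood of the whole sequence
logp = logsumexp(log_alpha[-1])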

I'm trying to think how we can vectorize the inner loop, open to suggestions.

Nvm, figured this out. It looks like:

for t in range(1, T):
    obs = data[t]
    # step_log_prob[i, j] = log P[i, j] + log_alpha[t-1, i]; reduce over previous state i
    step_log_prob = np.log(transition_probs) + log_alpha[t-1, :, None]
    log_alpha[t, :] = eval_logp(obs, x_dists) + logsumexp(step_log_prob, axis=0)
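
Once `log_alpha` is available, adding a textbook backward pass gives per-step posteriors p(s_t | data), which is one way to "back out" hidden states from a marginalized model as suggested earlier in the thread. This forward-backward sketch is not something the PR implements, and `state_posteriors` is a hypothetical name:

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp

def state_posteriors(log_emission, log_P, log_init):
    """p(s_t | all data) for every step via forward-backward in log space."""
    T, N = log_emission.shape
    log_alpha = np.zeros((T, N))
    log_beta = np.zeros((T, N))  # log_beta[T-1] = log 1 = 0
    log_alpha[0] = log_init + log_emission[0]
    for t in range(1, T):
        # Same vectorized forward step as above
        log_alpha[t] = log_emission[t] + logsumexp(log_P + log_alpha[t - 1, :, None], axis=0)
    for t in range(T - 2, -1, -1):
        # beta_t(i) = sum_j P[i, j] * p(y_{t+1} | j) * beta_{t+1}(j)
        log_beta[t] = logsumexp(log_P + log_emission[t + 1] + log_beta[t + 1], axis=1)
    log_post = log_alpha + log_beta
    # Normalize each step so the probabilities over states sum to 1
    return np.exp(log_post - logsumexp(log_post, axis=1, keepdims=True))

data = np.array([0, 0, 1])
log_emission = stats.bernoulli(p=[0.2, 0.6]).logpmf(data[:, None])
post = state_posteriors(log_emission,
                        np.log([[0.5, 0.5], [0.3, 0.7]]),
                        np.log([0.375, 0.625]))
```

Row t of `post` is p(s_t | data). Unlike Viterbi, these are per-step marginals, so their step-wise argmax need not be the most probable joint path in general.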

ricardoV94 commented 5 months ago

Categorical is one of the goals I have with #300

I think it's already working there, but I need to rebase and check once we merge this

ricardoV94 commented 5 months ago

Lags are a nice follow-up. The current distribution doesn't have a clear API for lags and batch dims, which is another reason I didn't address them here.

We just need to agree on this and then it should be straightforward to support both.

The design question is: how do you specify a Markov chain with 2 lags and an extra batch dimension? Say, something with shape (5, 100), two lags, and a different transition matrix for each of the five batched chains.
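
The API question is left open in this thread, but the computation for the batch part (one lag only) can be sketched: a leading batch axis on `P` broadcasts through the same forward recursion. The shape convention below (`log_emission` as (B, T, N), `log_P` as (B, N, N) with rows indexed by the previous state) and the name `batched_forward_logp` are assumptions, and the emission log-likelihoods are random placeholders:

```python
import numpy as np
from scipy.special import logsumexp

def batched_forward_logp(log_emission, log_P, log_init):
    """First-order forward algorithm over independent batched chains.

    Hypothetical shape convention: log_emission (B, T, N),
    log_P (B, N, N) with rows = previous state, log_init (B, N).
    Returns the (B,) marginal log-likelihoods.
    """
    log_alpha = log_init + log_emission[:, 0]
    for t in range(1, log_emission.shape[1]):
        # (B, N_prev, 1) + (B, N_prev, N_next), then sum out N_prev
        log_alpha = log_emission[:, t] + logsumexp(log_alpha[:, :, None] + log_P, axis=1)
    return logsumexp(log_alpha, axis=-1)

rng = np.random.default_rng(1)
B, T, N = 5, 100, 2
# A different transition matrix for each of the five chains
log_P = np.log(rng.dirichlet(np.ones(N), size=(B, N)))
log_init = np.log(np.full((B, N), 1.0 / N))
# Placeholder emission log-likelihoods, one row per (chain, step)
log_emission = np.log(rng.uniform(size=(B, T, N)))
logps = batched_forward_logp(log_emission, log_P, log_init)
```

This covers only the batch dimension; how lags should be declared alongside it is exactly the open design question.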

jessegrabowski commented 5 months ago

Yeah, good questions. You're right, it's not clear. I guess the distribution has to store the n_lags variable and marginalize will have to ask it? Not sure. In general, the way lags are handled is not good: at higher orders the transition matrix is almost certainly going to be sparse, so it makes more sense to make one huge (k**n, k**n) sparse matrix and store a hash table to index into it for specific lag tuples.

We could let the user declare the lagged matrices as a tensor (since it's a bit more natural IMO at least) then internally flatten it down and build the index table, then rebuild the tensors after sampling.
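
One way the flattening idea could look for two lags, as a sketch only: the user-facing tensor `P2` of shape (k, k, k) is rewritten as a first-order chain on state pairs, stored with `scipy.sparse`, and the flat index `i * k + j` plays the role of the index table. `flatten_two_lag` is a hypothetical helper, not an existing pymc-experimental function:

```python
import numpy as np
from scipy import sparse

def flatten_two_lag(P2):
    """Rewrite a 2-lag transition tensor as a first-order chain on state pairs.

    P2[i, j, m] = p(s_t = m | s_{t-2} = i, s_{t-1} = j), shape (k, k, k).
    The pair (i, j) gets flat index i * k + j (the "index table").
    From pair (i, j) the chain can only reach pairs (j, m), so each of the
    k**2 rows has at most k nonzeros.
    """
    k = P2.shape[0]
    rows, cols, vals = [], [], []
    for i in range(k):
        for j in range(k):
            for m in range(k):
                rows.append(i * k + j)
                cols.append(j * k + m)
                vals.append(P2[i, j, m])
    return sparse.csr_matrix((vals, (rows, cols)), shape=(k * k, k * k))

rng = np.random.default_rng(2)
k = 3
P2 = rng.dirichlet(np.ones(k), size=(k, k))  # each P2[i, j, :] sums to 1
M = flatten_two_lag(P2)
print(M.shape, M.nnz)  # (9, 9) 27
```

Rebuilding the tensor after sampling amounts to reading `M[i * k + j, j * k + m]` back into `P2[i, j, m]`. With n lags the matrix is k**n by k**n with k**(n + 1) nonzeros, a density of k**(1 - n).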

But this is all for another PR, I 100% agree.