Closed ricardoV94 closed 4 months ago
Is it using the Viterbi algorithm?
Currently just the forward algorithm* to compute the logp
*It's not pure forward because we are computing and storing p(data | state) for all data-state pairs outside the scan over state transition probabilities. We should be O(N^2*T) on compute, but we're not maximally efficient on memory.
If I understand correctly, Viterbi just gives the most probable sequence of hidden states in a maximum likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm oversimplifying.
yes Viterbi gives the posterior mode - but you are marginalizing the state to compute the likelihood here right?
but you are marginalizing the state to compute the likelihood here right?
Yes, but to be precise: to compute the logp of any dependent variables, which may be observed/unobserved or a mix.
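To make the distinction concrete, here is a minimal NumPy sketch of Viterbi in log space. The function name `viterbi` and its argument layout are illustrative assumptions, not part of this PR. It takes a max over predecessor states where the forward pass takes a sum, so it returns the single most probable state path rather than the marginal logp. The toy numbers are the ones used in the NumPy example further down the thread.

```python
import numpy as np

def viterbi(log_init, log_trans, log_emit):
    """Most probable hidden-state path (hypothetical helper, not the PR's code).

    log_init:  (K,)   log initial state probabilities
    log_trans: (K, K) log transition matrix, rows index the previous state
    log_emit:  (T, K) log p(data[t] | state)
    """
    T, K = log_emit.shape
    delta = np.empty((T, K))             # best log joint of any path ending in each state
    back = np.zeros((T, K), dtype=int)   # best predecessor for each (t, state)
    delta[0] = log_init + log_emit[0]
    for t in range(1, T):
        scores = delta[t - 1][:, None] + log_trans   # (previous, next)
        back[t] = scores.argmax(axis=0)
        delta[t] = scores.max(axis=0) + log_emit[t]
    # Trace the best path backwards from the best final state.
    path = np.empty(T, dtype=int)
    path[-1] = delta[-1].argmax()
    for t in range(T - 2, -1, -1):
        path[t] = back[t + 1, path[t + 1]]
    return path, delta[-1].max()

transition_probs = np.array([[0.5, 0.5], [0.3, 0.7]])
initial_probs = np.array([0.375, 0.625])
p_one = np.array([0.2, 0.6])             # Bernoulli p for each state
data = np.array([0, 0, 1])
log_emit = np.where(data[:, None] == 1, np.log(p_one), np.log1p(-p_one))

path, best_logp = viterbi(np.log(initial_probs), np.log(transition_probs), log_emit)
print(path, best_logp)
```

Replacing `max`/`argmax` with a log-sum-exp over the same `scores` array gives back the forward recursion, which is why the two algorithms share almost all of their structure.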
Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out but it doesn't seem to be the case.
I added the example from this youtube vid as a test case, so we can get to a solution.
I'm in the process of refactoring the logp function to compute alpha correctly, but it's typically a nested loop. Here's numpy code:
import numpy as np
from scipy import stats
from scipy.special import logsumexp

transition_probs = np.array([[0.5, 0.5],
                             [0.3, 0.7]])
initial_probs = np.array([0.375, 0.625])
T = 3
data = [0, 0, 1]
log_alpha = np.zeros((T, 2))
x_dists = [stats.bernoulli(p=0.2), stats.bernoulli(p=0.6)]

def eval_logp(x, dists):
    # log p(x | state) for every state
    return np.array([d.logpmf(x) for d in dists])

log_alpha[0, :] = np.log(initial_probs) + eval_logp(data[0], x_dists)
for t in range(1, T):
    obs = data[t]
    for s in range(transition_probs.shape[0]):
        # emission in state s + transition into s + previous alpha, summed over previous states
        step_log_prob = x_dists[s].logpmf(obs) + np.log(transition_probs[:, s]) + log_alpha[t-1, :]
        log_alpha[t, s] = logsumexp(step_log_prob)
I'm trying to think how we can vectorize the inner loop, open to suggestions.
Nvm figured this out, it looks like:
for t in range(1, T):
    obs = data[t]
    step_log_prob = np.log(transition_probs) + log_alpha[t-1, :, None]
    log_alpha[t, :] = eval_logp(obs, x_dists) + logsumexp(step_log_prob, axis=0)
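One way to sanity-check the vectorized recursion is to compare it against brute-force marginalization over all K**T state paths, which is cheap for this 3-step example. The sketch below is self-contained and uses the same toy values; the helper names `forward_logp` and `brute_force_logp` are illustrative rather than taken from the PR, and `np.logaddexp.reduce` stands in for SciPy's `logsumexp` to keep it NumPy-only.

```python
import itertools
import numpy as np

transition_probs = np.array([[0.5, 0.5], [0.3, 0.7]])
initial_probs = np.array([0.375, 0.625])
p_one = np.array([0.2, 0.6])             # Bernoulli p for each state
data = np.array([0, 0, 1])

# log p(data[t] | state), shape (T, K)
log_emit = np.where(data[:, None] == 1, np.log(p_one), np.log1p(-p_one))

def forward_logp(log_init, log_trans, log_emit):
    # Vectorized forward pass: emissions are folded in at every step.
    log_alpha = log_init + log_emit[0]
    for t in range(1, len(log_emit)):
        log_alpha = log_emit[t] + np.logaddexp.reduce(log_trans + log_alpha[:, None], axis=0)
    return np.logaddexp.reduce(log_alpha)

def brute_force_logp(log_init, log_trans, log_emit):
    # Sum the joint probability of the data and every possible state path.
    T, K = log_emit.shape
    terms = []
    for path in itertools.product(range(K), repeat=T):
        lp = log_init[path[0]] + log_emit[0, path[0]]
        for t in range(1, T):
            lp += log_trans[path[t - 1], path[t]] + log_emit[t, path[t]]
        terms.append(lp)
    return np.logaddexp.reduce(np.array(terms))

fwd = forward_logp(np.log(initial_probs), np.log(transition_probs), log_emit)
brute = brute_force_logp(np.log(initial_probs), np.log(transition_probs), log_emit)
print(fwd, brute)
```

The brute-force version scales as K**T, so it is only useful as a test oracle for short chains.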
Categorical is one of the goals I have with #300
I think it's already working there, but I need to rebase and check once we merge this
Lags would be a nice follow-up. The current distribution doesn't have a clear API for lags and batch dims, which further stopped me from addressing it here
We just need to agree on this and then it should be straightforward to support both.
The design question is: how do you specify a Markov chain with 2 lags and an extra batch dimension? Say something with shape (5, 100) with two lags but different transition matrices for each of the five batched chains.
Yeah, good questions. You're right, it's not clear. I guess the distribution has to store the n_lags variable and marginalize will have to ask it? Not sure. In general, the way lags are handled is not good: at higher orders the transition matrix is almost certainly going to be sparse, so it makes more sense to make one huge (k**n, k**n) sparse matrix and store a hash table to index into it for specific lag tuples.
We could let the user declare the lagged matrices as a tensor (since it's a bit more natural IMO at least) then internally flatten it down and build the index table, then rebuild the tensors after sampling.
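A rough sketch of what that flattening could look like, assuming the user supplies a tensor `P` of shape `(k,) * n + (k,)` where `P[s_{t-n}, ..., s_{t-1}, s_t]` is the probability of `s_t` given the `n` previous states. The function name `flatten_lagged_transitions` and the triplet output format are hypothetical, not an existing API. Each composite state is a lag tuple, and a tuple can only move to the tuple obtained by dropping its oldest state and appending the new one, so only `k**(n+1)` of the `k**(2n)` entries are non-zero.

```python
import itertools
import numpy as np

def flatten_lagged_transitions(P):
    """Turn an order-n transition tensor into sparse triplets over lag tuples.

    Hypothetical helper. Returns the lag-tuple index table plus (rows, cols, vals)
    triplets describing a (k**n, k**n) first-order transition matrix.
    """
    n = P.ndim - 1
    k = P.shape[-1]
    lag_tuples = list(itertools.product(range(k), repeat=n))
    index = {lags: i for i, lags in enumerate(lag_tuples)}   # hash table: lag tuple -> row
    rows, cols, vals = [], [], []
    for lags in lag_tuples:
        for nxt in range(k):
            rows.append(index[lags])
            cols.append(index[lags[1:] + (nxt,)])   # shift the window by one step
            vals.append(P[lags + (nxt,)])
    return index, np.array(rows), np.array(cols), np.array(vals)

# Example: k = 3 states, n = 2 lags, random rows that each sum to 1.
rng = np.random.default_rng(0)
k, n = 3, 2
P = rng.dirichlet(np.ones(k), size=(k,) * n)        # shape (3, 3, 3)
index, rows, cols, vals = flatten_lagged_transitions(P)

# Densify only to inspect the result; a real implementation would keep it sparse.
dense = np.zeros((k**n, k**n))
dense[rows, cols] = vals
print(dense.shape, len(vals))
```

The triplets are in the layout that sparse constructors such as `scipy.sparse.coo_matrix((vals, (rows, cols)), shape=...)` accept, and the `index` dict is the lookup table needed to map sampled composite states back to lag tuples.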
But this is all for another PR, I 100% agree.
The following example defines a 2-state HMM, with a 0.9 transition probability of staying in the same state, and a Normal emission centered around -1 for state 0 and 1 for state 1.
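As a plain NumPy illustration of the model just described (not the PyMC code from this PR), the sketch below simulates such a chain and evaluates the marginal logp of the emissions with the forward recursion. The unit emission standard deviation, the uniform initial state distribution, the chain length of 50, and all variable names are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

P = np.array([[0.9, 0.1],
              [0.1, 0.9]])        # 0.9 probability of staying in the same state
init = np.array([0.5, 0.5])       # assumption: uniform initial state distribution
mu = np.array([-1.0, 1.0])        # emission means for state 0 and state 1
sigma = 1.0                       # assumption: unit emission standard deviation
T = 50                            # assumption: chain length

# Simulate the hidden chain and the Normal emissions.
states = np.empty(T, dtype=int)
states[0] = rng.choice(2, p=init)
for t in range(1, T):
    states[t] = rng.choice(2, p=P[states[t - 1]])
emission = rng.normal(mu[states], sigma)

# log p(emission[t] | state), shape (T, 2): Normal log-density written out by hand.
log_emit = (-0.5 * ((emission[:, None] - mu) / sigma) ** 2
            - np.log(sigma) - 0.5 * np.log(2 * np.pi))

# Forward recursion: marginalize the hidden states out of the likelihood.
log_alpha = np.log(init) + log_emit[0]
for t in range(1, T):
    log_alpha = log_emit[t] + np.logaddexp.reduce(np.log(P) + log_alpha[:, None], axis=0)
marginal_logp = np.logaddexp.reduce(log_alpha)
print(marginal_logp)
```

Because the marginal sums over every state path, it is always at least as large as the joint log-probability of the emissions with any single path, including the simulated one.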
Not implemented
Higher-order lags and batched P matrices are not supported due to complexity (and me not grokking the exact API)
Closes #167