Closed ezhang94 closed 1 year ago
Passes the following script, run via `pytest test_hmm_inference.py`:
"""test_hmm_inference.py
Script to test equivalence of HMM posterior inference implementations.
"""
import pytest
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
from functools import partial
from jax_moseq.utils.distributions import sample_hmm_stateseq
from dynamax.hidden_markov_model.inference import hmm_posterior_sample
ATOL = 1e-1
NUM_TIMESTEPS = 50
SEED = jr.PRNGKey(3240)
CONFIGS = [
(SEED, NUM_TIMESTEPS, int(0.1*NUM_TIMESTEPS)),
(SEED, NUM_TIMESTEPS, int(0.4*NUM_TIMESTEPS)),
]
# =============================================================================
def random_hmm_args(key, num_timesteps, num_states, scale=1.0, n_mask=0):
"""Generate random HMM arguments, log-likelihoods, and masking.
Arguments
num_timesteps (int): number of timesteps in sequence, T
num_states (int): number of hmm states, K
scale (float): covariance scale
n_mask (int): number of timesteps to mask
Returns
initial_probs: shape [K]
transition_matrix: shape [K,K]
log_likelihoods: shape [T,K]
mask: shape [T]
"""
k1, k2, k3, k4 = jr.split(key, 4)
initial_probs = jr.uniform(k1, (num_states,))
initial_probs /= initial_probs.sum()
transition_matrix = jr.uniform(k2, (num_states, num_states))
transition_matrix /= transition_matrix.sum(1, keepdims=True)
log_likelihoods = scale * jr.normal(k3, (num_timesteps, num_states))
mask = jnp.ones(num_timesteps, dtype=int)
if n_mask > 0:
i_flip = jr.randint(k4, (n_mask,), 0, num_timesteps)
mask = mask.at[i_flip].set(0)
return initial_probs, transition_matrix, log_likelihoods, mask
def expected_state_probability(sampled_states, num_states):
"""Compute expected state probability at each timestep given samples.
Arguments
sample_states (array): shape [N,T], taking on values 0,...,K-1
num_states (int): Total number of states, K
Returns
expected_state_probabilities: shape [T,K]
"""
count_state = lambda k: (sampled_states==k).sum(axis=0)
counts = vmap(count_state, out_axes=-1)(jnp.arange(num_states))
return counts/len(sampled_states)
def new_moseq_sample(seed, initial_distribution, transition_matrix, log_likelihoods, mask):
"""Proposed HMM posterior sampling implementation using dynamax and masking."""
masked_log_likelihoods = log_likelihoods * mask[:,None]
return hmm_posterior_sample(seed, initial_distribution, transition_matrix, masked_log_likelihoods)
# =============================================================================
def test_all_valid(seed=SEED, num_timesteps=NUM_TIMESTEPS, num_states=5, num_samples=1000):
"""Compare dynamax, original moseq, and new moseq implementations."""
seed_params, seed_sample = jr.split(seed, 2)
seed_dynamax, seed_moseq, seed_proposed = jr.split(seed_sample, 3)
# Generate random HMM params
initial_probs, transition_matrix, log_likelihoods, mask \
= random_hmm_args(seed_params, num_timesteps, num_states)
# Sample from proposed implementation
_proposed_sample = vmap(partial(new_moseq_sample,
initial_distribution=initial_probs,
transition_matrix=transition_matrix,
log_likelihoods=log_likelihoods,
mask=mask,))
_, proposed_states = _proposed_sample(jr.split(seed_proposed, num_samples))
expected_proposed = expected_state_probability(proposed_states, num_states)
# --------------------------------------------
# Compare to original jax_moseq implementation
_moseq_sample = vmap(partial(sample_hmm_stateseq,
log_likelihoods=log_likelihoods,
mask=mask,
pi=transition_matrix,))
moseq_states, _ = _moseq_sample(jr.split(seed_moseq, num_samples))
expected_moseq = expected_state_probability(moseq_states, num_states)
# Ignore initial probabilities here, since original jax_moseq implementation
# hard-coded in initial distribution as uniform
assert jnp.allclose((expected_moseq*mask[:,None])[1:],
(expected_proposed*mask[:,None])[1:],
atol=ATOL)
# ------------------------------
# Compare to dynmax implentation
_dynamax_sample = vmap(partial(hmm_posterior_sample,
initial_distribution=initial_probs,
transition_matrix=transition_matrix,
log_likelihoods=log_likelihoods,))
_, dynamax_states = _dynamax_sample(jr.split(seed_dynamax, num_samples))
expected_dynamax = expected_state_probability(dynamax_states, num_states)
assert jnp.allclose(expected_dynamax, expected_proposed, atol=ATOL)
@pytest.mark.parametrize(["seed", "num_timesteps", "num_invalid"], CONFIGS)
def test_invalid(seed, num_timesteps, num_invalid, num_states=5, num_samples=10000):
"""Compare dynamax, original moseq, and new moseq implementations."""
seed_params, seed_sample = jr.split(seed, 2)
seed_dynamax, seed_moseq, seed_proposed = jr.split(seed_sample, 3)
# Generate random HMM params
initial_probs, transition_matrix, log_likelihoods, mask \
= random_hmm_args(seed_params, num_timesteps, num_states, n_mask=num_invalid)
# Sample from proposed implementation
_proposed_sample = vmap(partial(new_moseq_sample,
initial_distribution=initial_probs,
transition_matrix=transition_matrix,
log_likelihoods=log_likelihoods,
mask=mask,))
_, proposed_states = _proposed_sample(jr.split(seed_proposed, num_samples))
expected_proposed = expected_state_probability(proposed_states, num_states)
# --------------------------------------------
# Compare to original jax_moseq implementation
_moseq_sample = vmap(partial(sample_hmm_stateseq,
log_likelihoods=log_likelihoods,
mask=mask,
pi=transition_matrix,))
moseq_states, _ = _moseq_sample(jr.split(seed_moseq, num_samples))
expected_moseq = expected_state_probability(moseq_states, num_states)
# Ignore initial probabilities here, since original jax_moseq implementation
# hard-coded in initial distribution as uniform
err = abs((expected_moseq*mask[:,None])[1:] - (expected_proposed*mask[:,None])[1:])
assert (err < ATOL).mean() > 0.9
Namely, the original implementation defaults to placing all probability on state 0, whereas the proposed implementation evolves states according to the transition matrix.
Use `dynamax.hidden_markov_model.inference.hmm_posterior_sample` in place of
`jax_moseq.utils.distributions.sample_hmm_stateseq` (called by
`jax_moseq.models.arhmm.gibbs.resample_discrete_stateseqs`).

Approach
Masking can be handled by pre-applying the mask directly to the log-likelihoods.
For example, suppose we have an emission that we would like to ignore, i.e.
`mask_t = 0`. Then, in this proposal, `ll_t <- (log_likelihoods * mask)[t] = 0`.
So (see comments showing equivalence), evaluate that the two implementations
produce the same states in expectation.