dattalab / jax-moseq

Other
3 stars 5 forks source link

Port in dynamax hmm posterior inference #3

Closed ezhang94 closed 1 year ago

ezhang94 commented 1 year ago

Use dynamax.hidden_markov_model.inference.hmm_posterior_sample in place of jax_moseq.utils.distributions.sample_hmm_stateseq (called by jax_moseq.models.arhmm.gibbs.resample_discrete_stateseqs)

Approach

ezhang94 commented 1 year ago

Passes the following script, run via pytest test_hmm_inference.py:

"""test_hmm_inference.py

Script to test equivalence of HMM posterior inference implementations.
"""
import pytest
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
from functools import partial

from jax_moseq.utils.distributions import sample_hmm_stateseq
from dynamax.hidden_markov_model.inference import hmm_posterior_sample

ATOL = 1e-1

NUM_TIMESTEPS = 50
SEED = jr.PRNGKey(3240)
CONFIGS = [
    (SEED, NUM_TIMESTEPS, int(0.1*NUM_TIMESTEPS)),
    (SEED, NUM_TIMESTEPS, int(0.4*NUM_TIMESTEPS)),
]

# =============================================================================

def random_hmm_args(key, num_timesteps, num_states, scale=1.0, n_mask=0):
    """Generate random HMM arguments, log-likelihoods, and masking.

    Arguments
        key (jax.random.PRNGKey): random seed used for all draws
        num_timesteps (int): number of timesteps in sequence, T
        num_states (int): number of hmm states, K
        scale (float): scale (std dev) of the random log-likelihoods
        n_mask (int): number of timesteps to mask; exactly n_mask entries
            of the returned mask are 0

    Returns
        initial_probs: shape [K], normalized to sum to 1
        transition_matrix: shape [K,K], rows normalized to sum to 1
        log_likelihoods: shape [T,K]
        mask: shape [T], 1 = valid timestep, 0 = masked
    """
    k1, k2, k3, k4 = jr.split(key, 4)

    # Draw unnormalized probabilities, then normalize.
    initial_probs = jr.uniform(k1, (num_states,))
    initial_probs /= initial_probs.sum()
    transition_matrix = jr.uniform(k2, (num_states, num_states))
    transition_matrix /= transition_matrix.sum(1, keepdims=True)

    log_likelihoods = scale * jr.normal(k3, (num_timesteps, num_states))

    mask = jnp.ones(num_timesteps, dtype=int)
    if n_mask > 0:
        # Sample indices without replacement so exactly n_mask timesteps are
        # masked. (jr.randint could draw duplicates, masking fewer than
        # n_mask timesteps.)
        i_flip = jr.choice(k4, num_timesteps, (n_mask,), replace=False)
        mask = mask.at[i_flip].set(0)

    return initial_probs, transition_matrix, log_likelihoods, mask

def expected_state_probability(sampled_states, num_states):
    """Compute expected state probability at each timestep given samples.

    Arguments
        sampled_states (array): shape [N,T], taking on values 0,...,K-1
        num_states (int): Total number of states, K

    Returns
        expected_state_probabilities: shape [T,K], empirical frequency of
            each state across the N samples at each timestep
    """
    # One-hot encode each sample ([N,T] -> [N,T,K]) and average over the
    # sample axis; equivalent to counting occurrences of each state and
    # dividing by N.
    one_hot = sampled_states[..., None] == jnp.arange(num_states)
    return one_hot.mean(axis=0)

def new_moseq_sample(seed, initial_distribution, transition_matrix, log_likelihoods, mask):
    """Proposed HMM posterior sampling implementation using dynamax and masking.

    Timesteps with mask == 0 have their log-likelihoods zeroed out, making
    every state equally likely at those steps, before deferring to dynamax.
    """
    zeroed_log_likelihoods = mask[:, None] * log_likelihoods
    return hmm_posterior_sample(
        seed, initial_distribution, transition_matrix, zeroed_log_likelihoods)

# =============================================================================

def test_all_valid(seed=SEED, num_timesteps=NUM_TIMESTEPS, num_states=5, num_samples=1000):
    """Compare dynamax, original moseq, and new moseq implementations."""

    seed_params, seed_sample = jr.split(seed, 2)
    seed_dynamax, seed_moseq, seed_proposed = jr.split(seed_sample, 3)

    # Random HMM parameters; no timesteps are masked in this test.
    initial_probs, transition_matrix, log_likelihoods, mask = random_hmm_args(
        seed_params, num_timesteps, num_states)

    # Empirical state marginals under the proposed implementation.
    sample_proposed = vmap(
        partial(new_moseq_sample,
                initial_distribution=initial_probs,
                transition_matrix=transition_matrix,
                log_likelihoods=log_likelihoods,
                mask=mask))
    _, proposed_states = sample_proposed(jr.split(seed_proposed, num_samples))
    expected_proposed = expected_state_probability(proposed_states, num_states)

    # --------------------------------------------
    # Empirical state marginals under the original jax_moseq implementation.
    sample_moseq = vmap(
        partial(sample_hmm_stateseq,
                log_likelihoods=log_likelihoods,
                mask=mask,
                pi=transition_matrix))
    moseq_states, _ = sample_moseq(jr.split(seed_moseq, num_samples))
    expected_moseq = expected_state_probability(moseq_states, num_states)

    # Skip t=0: the original jax_moseq implementation hard-codes a uniform
    # initial distribution, so the first timestep is not comparable.
    masked_moseq = (expected_moseq * mask[:, None])[1:]
    masked_proposed = (expected_proposed * mask[:, None])[1:]
    assert jnp.allclose(masked_moseq, masked_proposed, atol=ATOL)

    # ------------------------------
    # Empirical state marginals under the dynamax implementation.
    sample_dynamax = vmap(
        partial(hmm_posterior_sample,
                initial_distribution=initial_probs,
                transition_matrix=transition_matrix,
                log_likelihoods=log_likelihoods))
    _, dynamax_states = sample_dynamax(jr.split(seed_dynamax, num_samples))
    expected_dynamax = expected_state_probability(dynamax_states, num_states)

    assert jnp.allclose(expected_dynamax, expected_proposed, atol=ATOL)

@pytest.mark.parametrize(["seed", "num_timesteps", "num_invalid"], CONFIGS)
def test_invalid(seed, num_timesteps, num_invalid, num_states=5, num_samples=10000):
    """Compare dynamax, original moseq, and new moseq implementations."""

    # Split seeds exactly as in test_all_valid; seed_dynamax is unused here
    # but kept so the other streams match.
    seed_params, seed_sample = jr.split(seed, 2)
    seed_dynamax, seed_moseq, seed_proposed = jr.split(seed_sample, 3)

    # Random HMM parameters with num_invalid masked timesteps.
    initial_probs, transition_matrix, log_likelihoods, mask = random_hmm_args(
        seed_params, num_timesteps, num_states, n_mask=num_invalid)

    # Empirical state marginals under the proposed implementation.
    sample_proposed = vmap(
        partial(new_moseq_sample,
                initial_distribution=initial_probs,
                transition_matrix=transition_matrix,
                log_likelihoods=log_likelihoods,
                mask=mask))
    _, proposed_states = sample_proposed(jr.split(seed_proposed, num_samples))
    expected_proposed = expected_state_probability(proposed_states, num_states)

    # --------------------------------------------
    # Empirical state marginals under the original jax_moseq implementation.
    sample_moseq = vmap(
        partial(sample_hmm_stateseq,
                log_likelihoods=log_likelihoods,
                mask=mask,
                pi=transition_matrix))
    moseq_states, _ = sample_moseq(jr.split(seed_moseq, num_samples))
    expected_moseq = expected_state_probability(moseq_states, num_states)

    # Skip t=0 (original implementation hard-codes a uniform initial
    # distribution). With invalid timesteps the implementations legitimately
    # differ, so only require 90% of entries to agree within tolerance.
    err = abs((expected_moseq * mask[:, None])[1:]
              - (expected_proposed * mask[:, None])[1:])
    assert (err < ATOL).mean() > 0.9
ezhang94 commented 1 year ago

Visualizations comparing original and proposed implementations

When all observations are valid, errors are within 1e-1 tolerance

image

When invalid observations are present, behavior differs (as expected)

Namely, the original implementation defaults to placing all probability on state 0, whereas the proposed implementation evolves states according to the transition matrix.

image

image