probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

My non-stationary Markov toy example isn't learning #363

Closed: DBraun closed this issue 3 weeks ago

DBraun commented 3 weeks ago

Sorry to ask a question that probably isn't bug-related, but I didn't think I'd get help elsewhere (StackOverflow, etc.). It's loosely related to https://github.com/probml/dynamax/issues/310.

I have runnable code for the following toy problem, but it doesn't learn.

Here's a summary of the code.

CyclingHMMInitialState subclasses StandardHMMInitialState. It implements distribution by returning a tfd.JointDistributionSequential over two components:

    def distribution(self, params, inputs=None):
        return tfd.JointDistributionSequential([
            tfd.Deterministic(0),  # Always start at the 0th transition matrix in the cycle
            tfd.Categorical(probs=params.probs),
        ])

The first component is the integer tracking which of the $M$ transition matrices we're using; since this is the initial state, it is always tfd.Deterministic(0). The second component is tfd.Categorical(probs=params.probs), as in the HMM casino tutorial.

CyclingHMMTransitions subclasses StandardHMMTransitions. Its concentration and transition_matrix are 3D ($M \times K \times K$) instead of 2D, and its distribution increments the index of the current transition matrix modulo $M$. Note that self.cycle_dim is $M$.

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        cycle_index = (cycle_index + 1) % self.cycle_dim
        return tfd.JointDistributionSequential([
            tfd.Deterministic(cycle_index),
            tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
        ])
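Concretely, the 3D prior amounts to one Dirichlet per row of each of the $M$ matrices. The NumPy sketch below (the helper `sample_cycling_transition_matrices` is hypothetical, not part of dynamax) mirrors the tiling in CyclingHMMTransitions.__init__. NumPy's `Generator.dirichlet` takes a 1-D concentration, so rows are drawn one at a time, whereas `tfd.Dirichlet` with an $M \times K \times K$ concentration treats the leading two axes as a batch and samples every row in one call.

```python
import numpy as np

def sample_cycling_transition_matrices(rng, cycle_dim, num_states,
                                       concentration=1.1, stickiness=0.0):
    """Hypothetical helper: draw (cycle_dim, K, K) transition matrices from the
    row-wise Dirichlet prior beta * 1_K + kappa * e_k, tiled over the cycle."""
    # K x K concentration, then tiled to (M, K, K) as in CyclingHMMTransitions.__init__
    conc = concentration * np.ones((num_states, num_states)) + stickiness * np.eye(num_states)
    conc = np.tile(conc[None], (cycle_dim, 1, 1))
    out = np.empty_like(conc)
    for m in range(cycle_dim):
        for k in range(num_states):
            # One Dirichlet draw per row: row k of matrix m is p(z_{t+1} | z_t = k).
            out[m, k] = rng.dirichlet(conc[m, k])
    return out

A = sample_cycling_transition_matrices(np.random.default_rng(0), cycle_dim=3, num_states=2)
print(A.shape)  # (3, 2, 2)
```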

The implementation of collect_suff_stats is interesting and possibly wrong.

    def collect_suff_stats(self, params, posterior, inputs=None):
        # return posterior.trans_probs
        num_timesteps = posterior.trans_probs.shape[0]
        trans_probs = jnp.stack([
            posterior.trans_probs[jnp.arange(i, num_timesteps, step=self.cycle_dim)].sum(axis=0)  # todo:
            for i in range(self.cycle_dim)])
        return trans_probs

Note that posterior.trans_probs, which the superclass's method would return, has shape $T \times K \times K$. I think this function should return an array of shape $M \times K \times K$: the slice [0, :, :] should sum the expected transitions at timesteps $t$ with $t \bmod M = 0$, the slice [1, :, :] those with $t \bmod M = 1$, and so on.
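A minimal NumPy sketch of this grouping (plain NumPy rather than JAX so it runs standalone; `group_by_cycle` and `group_by_cycle_scatter` are hypothetical names). It checks that the strided-slice version used in collect_suff_stats agrees with a scatter-add into bucket $t \bmod M$. In JAX the scatter form could be written as `jax.ops.segment_sum(trans_probs, jnp.arange(T) % M, num_segments=M)`, which avoids the Python loop; that is an untested suggestion, not part of the original code.

```python
import numpy as np

def group_by_cycle(trans_probs, cycle_dim):
    """Sum a (T, K, K) array of expected transitions into (cycle_dim, K, K);
    slice i collects the timesteps t with t % cycle_dim == i."""
    num_timesteps = trans_probs.shape[0]
    # Same strided-slice idea as collect_suff_stats above.
    return np.stack([trans_probs[i:num_timesteps:cycle_dim].sum(axis=0)
                     for i in range(cycle_dim)])

def group_by_cycle_scatter(trans_probs, cycle_dim):
    """Equivalent formulation: scatter-add each timestep into bucket t % cycle_dim."""
    out = np.zeros((cycle_dim,) + trans_probs.shape[1:])
    np.add.at(out, np.arange(trans_probs.shape[0]) % cycle_dim, trans_probs)
    return out

rng = np.random.default_rng(0)
T, K, M = 10, 2, 3
xi = rng.random((T, K, K))
xi /= xi.sum(axis=(1, 2), keepdims=True)  # each slice is a joint over (z_t, z_{t+1})
stats = group_by_cycle(xi, M)
print(stats.shape)  # (3, 2, 2)
print(np.allclose(stats, group_by_cycle_scatter(xi, M)))  # True
```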

CyclingHMMEmissions subclasses CategoricalHMMEmissions. Its distribution is fully deterministic.

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic([cycle_index]),
            tfd.Deterministic([state_index])
        ])

It therefore has no learnable parameters, so log_prior just returns 0. It also implements _compute_conditional_logliks:

    def _compute_conditional_logliks(self, params, emissions, inputs=None):
        a = emissions[1].reshape((-1,))
        a = jnp.round(a).astype(jnp.uint32)
        a = one_hot(a, num_classes=self.num_states)
        return jnp.where(a, jnp.zeros_like(a), jnp.full_like(a, fill_value=-jnp.inf))

The output has shape $T \times K$. Depending on the emission, each log-likelihood is either 0 (log 1) or -jnp.inf (the limit of log 0). I'm not sure about this function, but I also tried returning jnp.zeros(shape=(num_timesteps, self.num_states)), where num_timesteps = emissions[0].shape[0], and that didn't lead to successful learning either.
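The same mask can be sketched in NumPy, assuming (as in the code above) that emissions[1] holds the observed state indices as floats; `indicator_logliks` is a hypothetical helper. Returning all zeros instead makes every state sequence equally compatible with the data, so the marginal likelihood is $\sum_{z_{1:T}} p(z_{1:T}) = 1$ for every parameter setting and the data carry no information about the transition matrices; the 0/-inf mask is what ties the likelihood to the observed states.

```python
import numpy as np

def indicator_logliks(observed_states, num_states):
    """Hypothetical helper: (T, K) array of log p(y_t | z_t = k) for an emission
    that reveals the state exactly, i.e. 0 where k == y_t and -inf elsewhere."""
    obs = np.rint(np.asarray(observed_states)).reshape(-1).astype(int)
    mask = np.eye(num_states, dtype=bool)[obs]  # one-hot rows, shape (T, K)
    return np.where(mask, 0.0, -np.inf)

# Laid out like emissions[1] above: shape (T, 1), float-valued state indices.
ll = indicator_logliks(np.array([[0.0], [1.0], [1.0]]), num_states=2)
print(ll.shape)  # (3, 2)
```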

Finally, CyclingCategoricalHMM subclasses HMM. Its e_step passes hmm_two_filter_smoother a transition_fn that selects the $(t \bmod M)$-th transition matrix.
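For reference, here is a NumPy sketch of a forward filter with cycling transition matrices, checked against brute-force enumeration over all state paths. It assumes the transition out of timestep $t$ uses matrix $t \bmod M$, which is what `transition_fn = lambda index: transition_matrix[index % self.cycle_dim]` gives if the smoother calls transition_fn with the index of the source timestep (an assumption about dynamax internals, not verified here). `cycling_forward_loglik` and `brute_force_loglik` are hypothetical helpers, not dynamax functions.

```python
import numpy as np
from itertools import product

def cycling_forward_loglik(initial_probs, transition_matrices, log_likelihoods):
    """log p(y_{1:T}) for an HMM whose transition t -> t+1 uses
    transition_matrices[t % M]."""
    num_cycle = transition_matrices.shape[0]
    predicted = np.asarray(initial_probs, dtype=float)
    log_norm = 0.0
    for t, ll in enumerate(log_likelihoods):
        shift = ll.max()  # subtract the max for numerical stability
        unnorm = predicted * np.exp(ll - shift)
        norm = unnorm.sum()
        log_norm += np.log(norm) + shift
        filtered = unnorm / norm
        # Propagate with the matrix for this position in the cycle.
        predicted = filtered @ transition_matrices[t % num_cycle]
    return log_norm

def brute_force_loglik(initial_probs, transition_matrices, log_likelihoods):
    """Sum over all K**T state paths; only feasible for tiny T, used as a check."""
    num_cycle = transition_matrices.shape[0]
    T, K = log_likelihoods.shape
    total = 0.0
    for path in product(range(K), repeat=T):
        p = initial_probs[path[0]] * np.exp(log_likelihoods[0, path[0]])
        for t in range(1, T):
            p *= transition_matrices[(t - 1) % num_cycle][path[t - 1], path[t]]
            p *= np.exp(log_likelihoods[t, path[t]])
        total += p
    return np.log(total)

rng = np.random.default_rng(0)
M, K, T = 3, 2, 6
A = rng.random((M, K, K))
A /= A.sum(axis=-1, keepdims=True)
pi = rng.random(K)
pi /= pi.sum()
lls = rng.normal(size=(T, K))
print(cycling_forward_loglik(pi, A, lls), brute_force_loglik(pi, A, lls))  # should agree
```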

Here's the full code:

import numpy as np
from jaxtyping import Array, Float

from functools import partial
from typing import NamedTuple, Union, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax.nn import one_hot
from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar

import matplotlib.pyplot as plt

from dynamax.hidden_markov_model.models.abstractions import HMM
# from dynamax.hidden_markov_model.models.abstractions import HMMInitialState, HMMEmissions, HMMTransitions
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.categorical_hmm import CategoricalHMMEmissions, StandardHMMTransitions
# from dynamax.hidden_markov_model.models.categorical_hmm import ParamsCategoricalHMM, ParamsCategoricalHMMEmissions, ParamsStandardHMMTransitions
from dynamax.hidden_markov_model.inference import hmm_two_filter_smoother

class CyclingHMMInitialState(StandardHMMInitialState):
    """Initial distribution over (cycle index, hidden state); the cycle index always starts at 0.
    """
    def __init__(self,
                 num_states,
                 initial_probs_concentration=1.1):
        """
        Args:
            initial_probabilities[k]: prob(hidden(1)=k)
        """
        self.num_states = num_states
        self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)

    def distribution(self, params, inputs=None):
        return tfd.JointDistributionSequential([
            tfd.Deterministic(0),  # Always start at the 0th transition matrix in the cycle
            tfd.Categorical(probs=params.probs),
        ])

class ParamsCyclingHMMTransitions(NamedTuple):
    transition_matrix: Union[Float[Array, "cycle_dim state_dim state_dim"], ParameterProperties]

class CyclingHMMTransitions(StandardHMMTransitions):
    r"""Cycling model for HMM transitions, with $M$ transition matrices used in turn.

    We place a Dirichlet prior over the rows of each transition matrix $A$,

    $$A_k \sim \mathrm{Dir}(\beta 1_K + \kappa e_k)$$

    where

    * $1_K$ denotes a length-$K$ vector of ones,
    * $e_k$ denotes the one-hot vector with a 1 in the $k$-th position,
    * $\beta \in \mathbb{R}_+$ is the concentration, and
    * $\kappa \in \mathbb{R}_+$ is the `stickiness`.

    """
    def __init__(self, cycle_dim, num_states, concentration=1.1, stickiness=0.0):
        """
        Args:
            transition_matrix[m, j, k]: prob(hidden(t) = k | hidden(t-1) = j) under the m-th matrix of the cycle
        """
        self.cycle_dim = cycle_dim
        self.num_states = num_states
        concentration = \
            concentration * jnp.ones((num_states, num_states)) + \
            stickiness * jnp.eye(num_states)
        concentration = jnp.tile(jnp.expand_dims(concentration, axis=0), reps=(self.cycle_dim, 1, 1))  # todo:
        self.concentration = concentration

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        cycle_index = (cycle_index + 1) % self.cycle_dim
        return tfd.JointDistributionSequential([
            tfd.Deterministic(cycle_index),
            tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
        ])

    def initialize(self, key=None, method="prior", transition_matrix=None):
        """Initialize the model parameters and their corresponding properties.

        Args:
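            key (PRNGKey, optional): random key used to sample the transition matrices from the prior. Required if transition_matrix is None. Defaults to None.
            method (str, optional): initialization method (ignored here; the prior is always used). Defaults to "prior".
            transition_matrix (array, optional): manually specified transition matrices of shape (cycle_dim, num_states, num_states). Defaults to None.

        Returns:
            params, props: the transition parameters and their ParameterProperties.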
        """
        if transition_matrix is None:
            this_key, key = jr.split(key)
            transition_matrix = tfd.Dirichlet(self.concentration).sample(seed=this_key)

        # Package the results into dictionaries
        params = ParamsCyclingHMMTransitions(transition_matrix=transition_matrix)
        props = ParamsCyclingHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
        return params, props

    def collect_suff_stats(self, params, posterior, inputs=None):
        # return posterior.trans_probs
        num_timesteps = posterior.trans_probs.shape[0]
        trans_probs = jnp.stack([
            posterior.trans_probs[jnp.arange(i, num_timesteps, step=self.cycle_dim)].sum(axis=0)  # todo:
            for i in range(self.cycle_dim)])
        return trans_probs

class ParamsCyclingHMMEmissions(NamedTuple):
    pass

class ParamsCyclingHMM(NamedTuple):
    initial: ParamsStandardHMMInitialState
    transitions: ParamsCyclingHMMTransitions
    emissions: None

class CyclingHMMEmissions(CategoricalHMMEmissions):

    def __init__(self,
                 num_states,
                 emission_dim):
        self.num_states = num_states
        self.emission_dim = emission_dim

    @property
    def emission_shape(self):
        return [(self.emission_dim,), (self.emission_dim,)]  # todo:
        # return (2, self.emission_dim,)
        # return ((self.emission_dim,), (self.emission_dim,))
        # return (self.emission_dim,)

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic([cycle_index]),
            tfd.Deterministic([state_index])
        ])
        # return tfd.Deterministic(state_index)  # todo:
        # return tfd.Deterministic([state_index])  # todo:
        # return tfd.Independent(
        #   tfd.Deterministic([state_index]),
        #   reinterpreted_batch_ndims=0)

    def log_prior(self, params):
        # todo: the emissions are fully deterministic, so the log prior is 0, right?
        return 0

    def _compute_conditional_logliks(self, params, emissions, inputs=None):
        a = emissions[1].reshape((-1,))
        a = jnp.round(a).astype(jnp.uint32)
        a = one_hot(a, num_classes=self.num_states)
        return jnp.where(a, jnp.zeros_like(a), jnp.full_like(a, fill_value=-jnp.inf))

    def initialize(self, key=jr.PRNGKey(0), method="prior"):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".

        Returns:
            params: nested dataclasses of arrays containing model parameters.
            props: a nested dictionary of ParameterProperties to specify parameter constraints and whether or not they should be trained.
        """

        # Add parameters to the dictionary
        params = ParamsCyclingHMMEmissions()
        props = ParamsCyclingHMMEmissions()
        return params, props

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        # todo: the emissions are fully deterministic, so return empty dict?
        return dict()

    def m_step(self, params, props, batch_stats, m_step_state):
        # todo: the emissions are fully deterministic, so nothing to maximize?
        return params, m_step_state

class CyclingCategoricalHMM(HMM):
    r"""An HMM with conditionally independent categorical emissions.

    Let $y_t \in \{1,\ldots,C\}^N$ denote a vector of $N$ conditionally independent
    categorical emissions from $C$ classes at time $t$. In this model, the emission
    distribution is,

    $$p(y_t \mid z_t, \theta) = \prod_{n=1}^N \mathrm{Cat}(y_{tn} \mid \theta_{z_t,n})$$
    $$p(\theta) = \prod_{k=1}^K \prod_{n=1}^N \mathrm{Dir}(\theta_{k,n}; \gamma 1_C)$$

    with $\theta_{k,n} \in \Delta_C$ for $k=1,\ldots,K$ and $n=1,\ldots,N$ are the
    *emission probabilities* and $\gamma$ is their prior concentration.

    :param cycle_dim: number of $K \times K$ transition matrices to cycle through
    :param num_states: number of discrete states $K$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.

    """

    def __init__(self,
                 cycle_dim: int,
                 num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_stickiness: Scalar = 0.0):
        self.cycle_dim = cycle_dim
        self.emission_dim = emission_dim
        initial_component = CyclingHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = CyclingHMMTransitions(cycle_dim, num_states,
                                                     concentration=transition_matrix_concentration,
                                                     stickiness=transition_matrix_stickiness)
        emission_component = CyclingHMMEmissions(num_states, emission_dim)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey = jr.PRNGKey(0),
                   method: str = "prior",
                   initial_probs: Optional[Float[Array, "num_states"]] = None,
                   transition_matrix: Optional[Float[Array, "cycle_dim num_states num_states"]] = None,
                   ) -> Tuple[ParameterSet, PropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".
            initial_probs (array, optional): manually specified initial state probabilities. Defaults to None.
            transition_matrix (array, optional): manually specified transition matrix. Defaults to None.

        Returns:
            Model parameters and their properties.
        """
        key1, key2, key3 = jr.split(key, 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method,
                                                                                initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method,
                                                                                           transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method)
        return ParamsCyclingHMM(**params), ParamsCyclingHMM(**props)

    def e_step(self, params, emissions, inputs=None):
        """The E-step computes expected sufficient statistics under the
        posterior. In the generic case, we simply return the posterior itself.
        """
        initial_distribution, transition_matrix, log_likelihoods = self._inference_args(params, emissions, inputs)
        transition_fn = lambda index: transition_matrix[index % self.cycle_dim]
        posterior = hmm_two_filter_smoother(initial_distribution=initial_distribution, log_likelihoods=log_likelihoods,
                                            transition_matrix=None,  # None because we use `transition_fn`
                                            transition_fn=transition_fn)

        initial_stats = self.initial_component.collect_suff_stats(params.initial, posterior, inputs)
        transition_stats = self.transition_component.collect_suff_stats(params.transitions, posterior, inputs)
        emission_stats = self.emission_component.collect_suff_stats(params.emissions, posterior, emissions, inputs)
        return (initial_stats, transition_stats, emission_stats), posterior.marginal_loglik

def main():

    key = jr.PRNGKey(0)
    key, subkey = jr.split(key)

    # Define the model parameters
    cycle_dim = 3
    num_emissions = 1  # Only one emission at a time
    num_observable_states = 2

    # Initialize the parameters for the cycling model
    initial_probs = jnp.full((num_observable_states), 1.0/num_observable_states)
    # transition_matrix = jnp.full((cycle_dim, num_observable_states, num_observable_states), 1.0)

    # initial_probs = None
    transition_matrix = None
    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(.2*jr.normal(subkey, shape=(cycle_dim, num_observable_states, num_observable_states)), axis=-1)

    # Construct the CyclingCategoricalHMM
    hmm = CyclingCategoricalHMM(cycle_dim, num_observable_states, num_emissions,
                                # initial_probs_concentration=1.1,  # # todo: default 1.1
                                # transition_matrix_concentration=.5,  # todo: default 1.1
                                # transition_matrix_stickiness=0.0  # todo: default 0.0
                                )

    # Initialize the parameters struct with known values
    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey,
                               initial_probs=initial_probs,
                               transition_matrix=transition_matrix,
                               )

    # Generate synthetic data
    num_batches = 100
    num_timesteps = 2000

    key, subkey = jr.split(key)
    batch_states, batch_emissions = \
        jax.vmap(partial(hmm.sample, params, num_timesteps=num_timesteps))(
            jr.split(subkey, num_batches))

    print(f"batch_states.shape:    {batch_states[1].shape}")
    print(f"batch_emissions.shape: {batch_emissions[1].shape}")

    # note that batch_states[0] and batch_emissions[0] correspond
    # to the cycle_dim index, i.e. $$timestep % cycle_dim$$

    def print_params(params):
        jnp.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
        print("Initial probs:")
        print(params.initial.probs)
        print("Transition matrices:")
        for i in range(cycle_dim):
            print(f"Transition matrix {i}:")
            print(params.transitions.transition_matrix[i])

    print_params(params)

    # Train the model using EM
    num_iters = 100
    key, subkey = jr.split(key)
    em_params, em_param_props = hmm.initialize(subkey)
    em_params, log_probs = hmm.fit_em(em_params,
                                      em_param_props,
                                      batch_emissions,
                                      num_iters=num_iters)

    # Compute the "losses" from EM
    em_losses = -log_probs / batch_emissions[1].size

    # Compute the loss if you used the parameters that generated the data
    true_loss = jax.vmap(partial(hmm.marginal_log_prob, params))(batch_emissions).sum()
    true_loss += hmm.log_prior(params)
    true_loss = -true_loss / batch_emissions[1].size

    # Plot the learning curve
    plt.plot(em_losses, label="EM")
    plt.axhline(true_loss, color='k', linestyle=':', label="True Params")
    plt.legend()
    plt.xlim(-10, num_iters)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    _ = plt.title("Learning Curve")

    print('Learned parameters:')

    print_params(em_params)

    plt.show()

def simple_main():
    key = jr.PRNGKey(0)
    key, subkey = jr.split(key)

    cycle_dim = 2  # Simplified to 2 for testing
    num_states = 2
    num_emissions = 1
    num_timesteps = 10  # Simplified for testing

    initial_probs = jax.nn.softmax(.2*jr.normal(subkey, (num_states,)))
    print("Initial Probabilities:", initial_probs)
    print("Sum of Initial Probabilities:", jnp.sum(initial_probs))

    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(.2*jr.normal(subkey, (cycle_dim, num_states, num_states)), axis=-1)
    print("Transition Matrices:")
    for i in range(cycle_dim):
        print(f"Transition matrix {i}:\n{transition_matrix[i]}")

    hmm = CyclingCategoricalHMM(cycle_dim, num_states, num_emissions)

    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey, initial_probs=initial_probs, transition_matrix=transition_matrix)

    key, subkey = jr.split(key)
    states, emissions = hmm.sample(params, num_timesteps=num_timesteps, key=subkey)
    print("Sampled States:\n", states)
    print("Sampled Emissions:\n", emissions)

    (initial_stats, transition_stats, emission_stats), log_likelihood = hmm.e_step(params, emissions)

    print("Initial Stats:", initial_stats)
    print("Transition Stats:", transition_stats)
    print("Emission Stats:", emission_stats)
    print("Log Likelihood:", log_likelihood)

if __name__ == '__main__':
    main()
    # simple_main()

The output (the learning-curve plot is not reproduced here):

batch_states.shape:    (100, 2000)
batch_emissions.shape: (100, 2000, 1)
Initial probs:
[0.500 0.500]
Transition matrices:
Transition matrix 0:
[[0.558 0.442]
 [0.375 0.625]]
Transition matrix 1:
[[0.454 0.546]
 [0.382 0.618]]
Transition matrix 2:
[[0.429 0.571]
 [0.473 0.527]]
Learned parameters:
Initial probs:
[0.540 0.460]
Transition matrices:
Transition matrix 0:
[[0.455 0.545]
 [0.382 0.618]]
Transition matrix 1:
[[0.425 0.575]
 [0.471 0.529]]
Transition matrix 2:
[[0.558 0.442]
 [0.375 0.625]]
DBraun commented 3 weeks ago

I edited _compute_conditional_logliks in the post above. This changes the output graph and results, but it still doesn't seem to be learning.

DBraun commented 3 weeks ago

I found the bug. This is the correct version of CyclingHMMTransitions's distribution:

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic((cycle_index + 1) % self.cycle_dim),
            tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
        ])
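A NumPy simulation makes the off-by-one concrete (hypothetical helpers `simulate_states` and `estimate_by_cycle`; the matrix values are illustrative, not the ones from the run above). States are sampled with either the incremented-index lookup of the original distribution or the current-index lookup of the fix, and transitions are then counted grouped by $t \bmod M$, as the e_step and collect_suff_stats do. With the original lookup, estimate $i$ recovers true matrix $(i+1) \bmod M$, which matches the rotation visible in the learned parameters printed earlier.

```python
import numpy as np

def simulate_states(rng, initial_probs, transition_matrices, num_chains, num_timesteps, lookup_shift):
    """Simulate the hidden-state part of the (cycle_index, state) chain for many chains.
    The cycle index at time t is t % M (the initial distribution pins it to 0).
    lookup_shift=1 mimics the original distribution() (lookup with the incremented
    index); lookup_shift=0 mimics the fixed one (lookup with the current index)."""
    M, K, _ = transition_matrices.shape
    states = np.empty((num_chains, num_timesteps), dtype=int)
    states[:, 0] = rng.choice(K, size=num_chains, p=initial_probs)
    for t in range(num_timesteps - 1):
        A = transition_matrices[(t % M + lookup_shift) % M]
        cdf = A[states[:, t]].cumsum(axis=1)  # (num_chains, K) row CDFs
        u = rng.random((num_chains, 1))
        # Inverse-CDF sampling; the clip guards against cdf[-1] rounding below 1.
        states[:, t + 1] = np.minimum((u > cdf).sum(axis=1), K - 1)
    return states

def estimate_by_cycle(states, num_states, cycle_dim):
    """Count transitions t -> t+1 grouped by t % cycle_dim, then normalize rows."""
    counts = np.zeros((cycle_dim, num_states, num_states))
    for t in range(states.shape[1] - 1):
        np.add.at(counts[t % cycle_dim], (states[:, t], states[:, t + 1]), 1)
    return counts / counts.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
true_A = np.array([[[0.9, 0.1], [0.2, 0.8]],
                   [[0.5, 0.5], [0.7, 0.3]],
                   [[0.1, 0.9], [0.6, 0.4]]])  # illustrative values, M = 3, K = 2
pi = np.array([0.5, 0.5])
est_orig = estimate_by_cycle(simulate_states(rng, pi, true_A, 5000, 31, lookup_shift=1), 2, 3)
est_fixed = estimate_by_cycle(simulate_states(rng, pi, true_A, 5000, 31, lookup_shift=0), 2, 3)
print(np.abs(est_orig - np.roll(true_A, -1, axis=0)).max())  # small: est_orig[i] ~ true_A[(i + 1) % 3]
print(np.abs(est_fixed - true_A).max())  # small: est_fixed[i] ~ true_A[i]
```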

It is important not to increment the cycle_index that is used to look up params.transition_matrix! Before this fix, I looked at the trained matrices and noticed that they had the right values but their indices were off by one: trained matrix 0 held the values of true matrix 1, trained matrix 1 held those of true matrix 2, and so on. Anyway, here is the corrected code, which learns quickly!

import numpy as np
from jaxtyping import Array, Float

from functools import partial
from typing import NamedTuple, Union, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jax.nn import one_hot
import optax
from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
from dynamax.utils.utils import pytree_sum
from dynamax.types import Scalar

import matplotlib.pyplot as plt

from dynamax.hidden_markov_model.models.abstractions import HMM
# from dynamax.hidden_markov_model.models.abstractions import HMMInitialState, HMMEmissions, HMMTransitions
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.categorical_hmm import CategoricalHMMEmissions, StandardHMMTransitions
# from dynamax.hidden_markov_model.models.categorical_hmm import ParamsCategoricalHMM, ParamsCategoricalHMMEmissions, ParamsStandardHMMTransitions
from dynamax.hidden_markov_model.inference import hmm_two_filter_smoother

from dynamax.parameters import to_unconstrained, from_unconstrained
from dynamax.types import PRNGKey, Scalar
from dynamax.utils.optimize import run_sgd
from dynamax.utils.utils import ensure_array_has_batch_dim

class CyclingHMMInitialState(StandardHMMInitialState):
    """Initial distribution over (cycle index, hidden state); the cycle index always starts at 0.
    """
    def __init__(self,
                 num_states,
                 initial_probs_concentration=1.1):
        """
        Args:
            initial_probabilities[k]: prob(hidden(1)=k)
        """
        self.num_states = num_states
        self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)

    def distribution(self, params, inputs=None):
        return tfd.JointDistributionSequential([
            tfd.Deterministic(0),  # Always start at the 0th transition matrix in the cycle
            tfd.Categorical(probs=params.probs),
        ])

class ParamsCyclingHMMTransitions(NamedTuple):
    transition_matrix: Union[Float[Array, "cycle_dim state_dim state_dim"], ParameterProperties]

class CyclingHMMTransitions(StandardHMMTransitions):
    r"""Cycling model for HMM transitions, with $M$ transition matrices used in turn.

    We place a Dirichlet prior over the rows of each transition matrix $A$,

    $$A_k \sim \mathrm{Dir}(\beta 1_K + \kappa e_k)$$

    where

    * $1_K$ denotes a length-$K$ vector of ones,
    * $e_k$ denotes the one-hot vector with a 1 in the $k$-th position,
    * $\beta \in \mathbb{R}_+$ is the concentration, and
    * $\kappa \in \mathbb{R}_+$ is the `stickiness`.

    """
    def __init__(self, cycle_dim, num_states, concentration=1.1, stickiness=0.0):
        """
        Args:
            transition_matrix[m, j, k]: prob(hidden(t) = k | hidden(t-1) = j) under the m-th matrix of the cycle
        """
        self.cycle_dim = cycle_dim
        self.num_states = num_states
        concentration = \
            concentration * jnp.ones((num_states, num_states)) + \
            stickiness * jnp.eye(num_states)
        concentration = jnp.tile(jnp.expand_dims(concentration, axis=0), reps=(self.cycle_dim, 1, 1))  # todo:
        self.concentration = concentration

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic((cycle_index + 1) % self.cycle_dim),
            tfd.Categorical(probs=params.transition_matrix[cycle_index.astype(jnp.uint32), state_index])
        ])

    def initialize(self, key=None, method="prior", transition_matrix=None):
        """Initialize the model parameters and their corresponding properties.

        Args:
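            key (PRNGKey, optional): random key used to sample the transition matrices from the prior. Required if transition_matrix is None. Defaults to None.
            method (str, optional): initialization method (ignored here; the prior is always used). Defaults to "prior".
            transition_matrix (array, optional): manually specified transition matrices of shape (cycle_dim, num_states, num_states). Defaults to None.

        Returns:
            params, props: the transition parameters and their ParameterProperties.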
        """
        if transition_matrix is None:
            this_key, key = jr.split(key)
            transition_matrix = tfd.Dirichlet(self.concentration).sample(seed=this_key)

        # Package the results into dictionaries
        params = ParamsCyclingHMMTransitions(transition_matrix=transition_matrix)
        props = ParamsCyclingHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
        return params, props

    def collect_suff_stats(self, params, posterior, inputs=None):
        # return posterior.trans_probs
        num_timesteps = posterior.trans_probs.shape[0]
        trans_probs = jnp.stack([
            posterior.trans_probs[jnp.arange(i, num_timesteps, step=self.cycle_dim)].sum(axis=0)  # todo:
            for i in range(self.cycle_dim)])
        return trans_probs

class ParamsCyclingHMMEmissions(NamedTuple):
    pass

class ParamsCyclingHMM(NamedTuple):
    initial: ParamsStandardHMMInitialState
    transitions: ParamsCyclingHMMTransitions
    emissions: None

class CyclingHMMEmissions(CategoricalHMMEmissions):

    def __init__(self,
                 num_states,
                 emission_dim):
        self.num_states = num_states
        self.emission_dim = emission_dim

    @property
    def emission_shape(self):
        return [(self.emission_dim,), (self.emission_dim,)]  # todo:
        # return (2, self.emission_dim,)
        # return ((self.emission_dim,), (self.emission_dim,))
        # return (self.emission_dim,)

    def distribution(self, params, state, inputs=None):
        cycle_index, state_index = state
        return tfd.JointDistributionSequential([
            tfd.Deterministic([cycle_index]),
            tfd.Deterministic([state_index])
        ])
        # return tfd.Deterministic(state_index)  # todo:
        # return tfd.Deterministic([state_index])  # todo:
        # return tfd.Independent(
        #   tfd.Deterministic([state_index]),
        #   reinterpreted_batch_ndims=0)

    def log_prior(self, params):
        # todo: the emissions are fully deterministic, so the log prior is 0, right?
        return 0

    def _compute_conditional_logliks(self, params, emissions, inputs=None):
        # todo:
        a = emissions[1].reshape((-1,))
        a = jnp.round(a).astype(jnp.uint32)
        a = one_hot(a, num_classes=self.num_states)
        return jnp.where(a, jnp.zeros_like(a), jnp.full_like(a, fill_value=-jnp.inf))

    def initialize(self, key=jr.PRNGKey(0), method="prior"):
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".

        Returns:
            params: nested dataclasses of arrays containing model parameters.
            props: a nested dictionary of ParameterProperties to specify parameter constraints and whether or not they should be trained.
        """

        # Add parameters to the dictionary
        params = ParamsCyclingHMMEmissions()
        props = ParamsCyclingHMMEmissions()
        return params, props

    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
        # todo: the emissions are fully deterministic, so return empty dict?
        return dict()

    def m_step(self, params, props, batch_stats, m_step_state):
        # todo: the emissions are fully deterministic, so nothing to maximize?
        return params, m_step_state

class CyclingCategoricalHMM(HMM):
    r"""An HMM with conditionally independent categorical emissions.

    Let $y_t \in \{1,\ldots,C\}^N$ denote a vector of $N$ conditionally independent
    categorical emissions from $C$ classes at time $t$. In this model, the emission
    distribution is,

    $$p(y_t \mid z_t, \theta) = \prod_{n=1}^N \mathrm{Cat}(y_{tn} \mid \theta_{z_t,n})$$
    $$p(\theta) = \prod_{k=1}^K \prod_{n=1}^N \mathrm{Dir}(\theta_{k,n}; \gamma 1_C)$$

    with $\theta_{k,n} \in \Delta_C$ for $k=1,\ldots,K$ and $n=1,\ldots,N$ are the
    *emission probabilities* and $\gamma$ is their prior concentration.

    :param cycle_dim: number of $K-K$ transition matrices to cycle through
    :param num_states: number of discrete states $K$
    :param emission_dim: number of conditionally independent emissions $N$
    :param initial_probs_concentration: $\alpha$
    :param transition_matrix_concentration: $\beta$
    :param transition_matrix_stickiness: optional hyperparameter to boost the concentration on the diagonal of the transition matrix.

    """

    def __init__(self,
                 cycle_dim: int,
                 num_states: int,
                 emission_dim: int,
                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]] = 1.1,
                 transition_matrix_stickiness: Scalar = 0.0):
        self.cycle_dim = cycle_dim
        self.emission_dim = emission_dim
        initial_component = CyclingHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
        transition_component = CyclingHMMTransitions(cycle_dim, num_states,
                                                     concentration=transition_matrix_concentration,
                                                     stickiness=transition_matrix_stickiness)
        emission_component = CyclingHMMEmissions(num_states, emission_dim)
        super().__init__(num_states, initial_component, transition_component, emission_component)

    def initialize(self,
                   key: jr.PRNGKey = jr.PRNGKey(0),
                   method: str = "prior",
                   initial_probs: Optional[Float[Array, "num_states"]] = None,
                   transition_matrix: Optional[Float[Array, "cycle_dim num_states num_states"]] = None,
                   ) -> Tuple[ParameterSet, PropertySet]:
        """Initialize the model parameters and their corresponding properties.

        You can either specify parameters manually via the keyword arguments, or you can have
        them set automatically. If any parameters are not specified, you must supply a PRNGKey.
        Parameters will then be sampled from the prior (if `method==prior`).

        Note: in the future we may support more initialization schemes, like K-Means.

        Args:
            key (PRNGKey, optional): random number generator for unspecified parameters. Must not be None if there are any unspecified parameters. Defaults to jr.PRNGKey(0).
            method (str, optional): method for initializing unspecified parameters. Currently, only "prior" is allowed. Defaults to "prior".
            initial_probs (array, optional): manually specified initial state probabilities, shape (num_states,). Defaults to None.
            transition_matrix (array, optional): manually specified stack of transition matrices, shape (cycle_dim, num_states, num_states). Defaults to None.

        Returns:
            Model parameters and their properties.
        """
        key1, key2, key3 = jr.split(key, 3)
        params, props = dict(), dict()
        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method,
                                                                                initial_probs=initial_probs)
        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method,
                                                                                           transition_matrix=transition_matrix)
        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method)
        return ParamsCyclingHMM(**params), ParamsCyclingHMM(**props)

    def e_step(self, params, emissions, inputs=None):
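        """The E-step computes expected sufficient statistics under the posterior.

        It runs the two-filter smoother with a time-varying transition function that
        cycles through the `cycle_dim` transition matrices, then collects each component's
        sufficient statistics from the resulting posterior.
        """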
        initial_distribution, transition_matrix, log_likelihoods = self._inference_args(params, emissions, inputs)
        transition_fn = lambda index: transition_matrix[index % self.cycle_dim]
        posterior = hmm_two_filter_smoother(initial_distribution=initial_distribution, log_likelihoods=log_likelihoods,
                                            transition_matrix=None,  # None because we use `transition_fn`
                                            transition_fn=transition_fn)

        initial_stats = self.initial_component.collect_suff_stats(params.initial, posterior, inputs)
        transition_stats = self.transition_component.collect_suff_stats(params.transitions, posterior, inputs)
        emission_stats = self.emission_component.collect_suff_stats(params.emissions, posterior, emissions, inputs)
        return (initial_stats, transition_stats, emission_stats), posterior.marginal_loglik

    def fit_sgd(
        self,
        params: ParameterSet,
        props: PropertySet,
        emissions: Union[Float[Array, "num_timesteps emission_dim"],
                         Float[Array, "num_batches num_timesteps emission_dim"]],
        inputs: Optional[Union[Float[Array, "num_timesteps input_dim"],
                               Float[Array, "num_batches num_timesteps input_dim"]]]=None,
        optimizer: optax.GradientTransformation=optax.adam(1e-3),
        batch_size: int=1,
        num_epochs: int=50,
        shuffle: bool=False,
        key: jr.PRNGKey=jr.PRNGKey(0)
    ) -> Tuple[ParameterSet, Float[Array, "niter"]]:
        r"""Compute a parameter MLE/MAP estimate using stochastic gradient descent (SGD).

        SGD aims to find parameters that maximize the marginal log probability,

        $$\theta^\star = \mathrm{argmax}_\theta \; \log p(y_{1:T}, \theta \mid u_{1:T})$$

        by minimizing the _negative_ of that quantity.

        *Note:* ``emissions`` *and* ``inputs`` *can either be single sequences or batches of sequences.*

        On each iteration, the algorithm grabs a *minibatch* of sequences and takes a gradient step.
        One pass through the entire set of sequences is called an *epoch*.

        Args:
            params: model parameters $\theta$
            props: properties specifying which parameters should be learned
            emissions: one or more sequences of emissions
            inputs: one or more sequences of corresponding inputs
            optimizer: an `optax` optimizer for minimization
            batch_size: number of sequences per minibatch
            num_epochs: number of epochs of SGD to run
            key: a random number generator for selecting minibatches

        Returns:
            tuple of new parameters and losses (negative scaled marginal log probs) over the course of SGD iterations.

        """
        # Make sure the emissions and inputs have batch dimensions
        batch_emissions = ensure_array_has_batch_dim(emissions, self.emission_shape)
        batch_inputs = ensure_array_has_batch_dim(inputs, self.inputs_shape)

        unc_params = to_unconstrained(params, props)

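        def _loss_fn(unc_params, minibatch):
            """Default objective function."""
            params = from_unconstrained(unc_params, props)
            minibatch_emissions, minibatch_inputs = minibatch
            # Emissions are a (cycle_index, value) tuple, so index [0] to get array shapes.
            # This is the number of sequences in the full dataset, not the number of time steps.
            num_sequences = len(batch_emissions[0])
            # Rescale the minibatch log-likelihood to the size of the full dataset
            scale = num_sequences / len(minibatch_emissions[0])
            minibatch_lls = jax.vmap(partial(self.marginal_log_prob, params))(minibatch_emissions, minibatch_inputs)
            lp = self.log_prior(params) + minibatch_lls.sum() * scale
            # Normalize by the total number of emission entries
            return -lp / batch_emissions[0].size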

        dataset = (batch_emissions, batch_inputs)
        unc_params, losses = run_sgd(_loss_fn,
                                     unc_params,
                                     dataset,
                                     optimizer=optimizer,
                                     batch_size=batch_size,
                                     num_epochs=num_epochs,
                                     shuffle=shuffle,
                                     key=key)

        params = from_unconstrained(unc_params, props)
        return params, losses

def main():

    key = jr.PRNGKey(2)
    key, subkey = jr.split(key)

    # Define the model parameters
    cycle_dim = 3
    num_emissions = 1  # Only one emission at a time
    num_observable_states = 2

    # Initialize the parameters for the cycling model
    initial_probs = jnp.full((num_observable_states), 1.0/num_observable_states)
    # transition_matrix = jnp.full((cycle_dim, num_observable_states, num_observable_states), 1.0)

    # initial_probs = None
    transition_matrix = None
    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(.2*jr.normal(subkey, shape=(cycle_dim, num_observable_states, num_observable_states)), axis=-1)

    # Construct the CyclingCategoricalHMM
    hmm = CyclingCategoricalHMM(cycle_dim, num_observable_states, num_emissions,
                                # initial_probs_concentration=1.1,  # # todo: default 1.1
                                # transition_matrix_concentration=.5,  # todo: default 1.1
                                # transition_matrix_stickiness=0.0  # todo: default 0.0
                                )

    # Initialize the parameters struct with known values
    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey,
                               initial_probs=initial_probs,
                               transition_matrix=transition_matrix,
                               )

    # Generate synthetic data
    num_batches = 10000
    num_timesteps = 100

    key, subkey = jr.split(key)
    batch_states, batch_emissions = \
        jax.vmap(partial(hmm.sample, params, num_timesteps=num_timesteps))(
            jr.split(subkey, num_batches))

    print(f"batch_states.shape:    {batch_states[1].shape}")
    print(f"batch_emissions.shape: {batch_emissions[1].shape}")

    # Note: batch_states[0] and batch_emissions[0] hold the cycle index,
    # i.e. timestep % cycle_dim.

    def print_params(params):
        jnp.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
        print("Initial probs:")
        print(params.initial.probs)
        print("Transition matrices:")
        for i in range(cycle_dim):
            print(f"Transition matrix {i}:")
            print(params.transitions.transition_matrix[i])
        print('')

    print('True Params:')
    print_params(params)

    # Train the model using EM
    num_iters = 20
    key, subkey = jr.split(key)
    em_params, em_param_props = hmm.initialize(subkey)
    em_params, log_probs = hmm.fit_em(em_params,
                                      em_param_props,
                                      batch_emissions,
                                      num_iters=num_iters)

    # Train the model using SGD
    sgd_params, sgd_param_props = hmm.initialize(key)
    sgd_key, key = jr.split(key)
    sgd_params, sgd_losses = hmm.fit_sgd(sgd_params,
                                         sgd_param_props,
                                         batch_emissions,
                                         optimizer=optax.sgd(learning_rate=1e-2, momentum=0.95),
                                         batch_size=num_batches//10,
                                         num_epochs=num_iters,
                                         key=sgd_key)

    print('SGD Params:')
    print_params(sgd_params)

    # Compute the "losses" from EM
    em_losses = -log_probs / batch_emissions[1].size

    # Compute the loss if you used the parameters that generated the data
    true_loss = jax.vmap(partial(hmm.marginal_log_prob, params))(batch_emissions).sum()
    true_loss += hmm.log_prior(params)
    true_loss = -true_loss / batch_emissions[1].size

    # Plot the learning curve
    plt.plot(sgd_losses, label=f"SGD (mini-batch size = {num_batches//10})")
    plt.plot(em_losses, label="EM")
    plt.axhline(true_loss, color='k', linestyle=':', label="True Params")
    plt.legend()
    plt.xlim(-2, num_iters)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    _ = plt.title("Learning Curve")

    print('EM learned parameters:')
    print_params(em_params)

    plt.show()

def simple_main():
    key = jr.PRNGKey(0)
    key, subkey = jr.split(key)

    cycle_dim = 3
    num_states = 2
    num_emissions = 1
    num_timesteps = 100  # Simplified for testing

    initial_probs = jax.nn.softmax(.2*jr.normal(subkey, (num_states,)))
    print("Initial Probabilities:", initial_probs)
    print("Sum of Initial Probabilities:", jnp.sum(initial_probs))

    key, subkey = jr.split(key)
    transition_matrix = jax.nn.softmax(2*jr.normal(subkey, (cycle_dim, num_states, num_states)), axis=-1)
    print("Transition Matrices:")
    for i in range(cycle_dim):
        print(f"Transition matrix {i}:\n{transition_matrix[i]}")

    hmm = CyclingCategoricalHMM(cycle_dim, num_states, num_emissions)

    key, subkey = jr.split(key)
    params, _ = hmm.initialize(subkey, initial_probs=initial_probs, transition_matrix=transition_matrix)

    key, subkey = jr.split(key)
    states, emissions = hmm.sample(params, num_timesteps=num_timesteps, key=subkey)
    print("Sampled States:\n", states)
    print("Sampled Emissions:\n", emissions)

    (initial_stats, transition_stats, emission_stats), log_likelihood = hmm.e_step(params, emissions)

    log_likelihood /= num_timesteps

    print("Initial Stats:", initial_stats)
    print("Transition Stats:", transition_stats)
    print("Emission Stats:", emission_stats)
    print("Log Likelihood:", log_likelihood)

if __name__ == '__main__':
    main()
    # simple_main()
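
The `e_step` above passes a time-indexed `transition_fn` to `hmm_two_filter_smoother`. For reference, here is a minimal NumPy sketch (not dynamax code) of the forward recursion for a cycling HMM. It assumes the convention of the `distribution` method shown earlier in this issue: the cycle index is incremented before the matrix is chosen, so the step from $t$ to $t+1$ uses matrix $(t+1) \bmod M$. The names `cycling_forward_loglik`, `deterministic_log_likelihoods`, and `_logsumexp` are hypothetical. If dynamax's filters apply `transition_fn(t)` to the step from $t$ to $t+1$ (an assumption worth checking against the dynamax source), matching this sampler would mean indexing with `(index + 1) % self.cycle_dim` rather than `index % self.cycle_dim`.

```python
import numpy as np


def _logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis` (hypothetical helper)."""
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # avoid (-inf) - (-inf) = nan
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cycling_forward_loglik(initial_probs, transition_matrices, log_likelihoods):
    """Marginal log-likelihood of one sequence under a cycling HMM.

    initial_probs:       (K,)       distribution of z_0
    transition_matrices: (M, K, K)  A[m][i, j] = p(z_{t+1} = j | z_t = i)
    log_likelihoods:     (T, K)     log p(y_t | z_t = k)

    Assumed convention: the step from t to t+1 uses transition_matrices[(t + 1) % M],
    i.e. the cycle index is advanced before the matrix is chosen.
    """
    num_cycles = transition_matrices.shape[0]
    log_A = np.log(transition_matrices)
    # log alpha_0(k) = log pi(k) + log p(y_0 | z_0 = k)
    log_alpha = np.log(initial_probs) + log_likelihoods[0]
    for t in range(log_likelihoods.shape[0] - 1):
        log_A_t = log_A[(t + 1) % num_cycles]
        # log alpha_{t+1}(j) = logsumexp_i[log alpha_t(i) + log A_t(i, j)] + log p(y_{t+1} | j)
        log_alpha = _logsumexp(log_alpha[:, None] + log_A_t, axis=0) + log_likelihoods[t + 1]
    return _logsumexp(log_alpha, axis=0)


def deterministic_log_likelihoods(observed_states, num_states):
    """0 where the state equals the observation, -inf elsewhere (deterministic emissions)."""
    observed_states = np.asarray(observed_states)
    return np.where(np.arange(num_states) == observed_states[:, None], 0.0, -np.inf)


if __name__ == "__main__":
    pi = np.array([0.3, 0.7])
    A = np.array([[[0.9, 0.1], [0.2, 0.8]],
                  [[0.3, 0.7], [0.6, 0.4]],
                  [[0.5, 0.5], [0.1, 0.9]]])
    y = np.array([1, 0, 0, 1, 1])
    ll = deterministic_log_likelihoods(y, num_states=2)
    # With deterministic emissions this is the log-probability of the observed path:
    # pi[1] * A[1][1,0] * A[2][0,0] * A[0][0,1] * A[1][1,1]
    print(cycling_forward_loglik(pi, A, ll), np.log(0.7 * 0.6 * 0.5 * 0.1 * 0.4))
```

With deterministic emissions the log-likelihood matrix is 0 at the observed state and $-\infty$ elsewhere, mirroring the `jnp.where` in `CyclingHMMEmissions`, so the marginal log-likelihood reduces to the log-probability of the observed path and depends directly on which matrix each step is paired with.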

output:

[Image: "Learning Curve" plot of loss vs. epoch for SGD, EM, and the true parameters]

batch_states.shape:    (10000, 100)
batch_emissions.shape: (10000, 100, 1)
True Params:
Initial probs:
[0.500 0.500]
Transition matrices:
Transition matrix 0:
[[0.533 0.467]
 [0.530 0.470]]
Transition matrix 1:
[[0.492 0.508]
 [0.426 0.574]]
Transition matrix 2:
[[0.526 0.474]
 [0.593 0.407]]

SGD Params:
Initial probs:
[0.070 0.930]
Transition matrices:
Transition matrix 0:
[[0.213 0.787]
 [0.964 0.036]]
Transition matrix 1:
[[0.340 0.660]
 [0.720 0.280]]
Transition matrix 2:
[[0.518 0.482]
 [0.522 0.478]]

EM learned parameters:
Initial probs:
[0.497 0.503]
Transition matrices:
Transition matrix 0:
[[0.535 0.465]
 [0.531 0.469]]
Transition matrix 1:
[[0.491 0.509]
 [0.426 0.574]]
Transition matrix 2:
[[0.526 0.474]
 [0.593 0.407]]
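
The EM estimates track the true matrices closely. One way to see why: assuming each emission reveals the state (as the 0/$-\infty$ log-likelihoods in `CyclingHMMEmissions` suggest), the posterior over states is the observed path itself, so the transition update reduces to normalized transition counts grouped by cycle index. The NumPy sketch below, using the hypothetical name `cycling_transition_mle`, computes that maximum-likelihood estimate under the same convention as the sampler (the step from $t$ to $t+1$ uses matrix $(t+1) \bmod M$). It ignores the Dirichlet prior set by `transition_matrix_concentration`, and it falls back to a uniform row when a row has no observed transitions.

```python
import numpy as np


def cycling_transition_mle(states, cycle_dim, num_states):
    """MLE of `cycle_dim` transition matrices from fully observed state paths.

    states: int array of shape (num_sequences, num_timesteps), or a single path.
    Assumed convention: the step from t to t+1 uses matrix (t + 1) % cycle_dim.
    Rows with no observed transitions are set to uniform.
    """
    states = np.atleast_2d(np.asarray(states))
    num_timesteps = states.shape[1]
    counts = np.zeros((cycle_dim, num_states, num_states))
    # Cycle index for each transition t -> t+1 (t = 0, ..., T-2), repeated for every sequence
    cycle_idx = np.broadcast_to(np.arange(1, num_timesteps) % cycle_dim, states[:, 1:].shape)
    # counts[c, i, j] += 1 for every transition i -> j taken with matrix c
    # (np.add.at accumulates correctly when the same index appears more than once)
    np.add.at(counts, (cycle_idx, states[:, :-1], states[:, 1:]), 1.0)
    row_sums = counts.sum(axis=-1, keepdims=True)
    # Nonzero row sums are at least 1, so np.maximum only guards the empty rows
    return np.where(row_sums > 0, counts / np.maximum(row_sums, 1.0), 1.0 / num_states)


if __name__ == "__main__":
    # A single toy path with cycle_dim = 3 and two states
    print(cycling_transition_mle([0, 1, 1, 0, 1, 0, 0], cycle_dim=3, num_states=2))
```

Applied to the state component of the sampled data (`batch_states[1]` in `main()`), this gives a closed-form estimate that can be compared with the EM and SGD results above.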