jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

0.14.9 - model.bake() - Issue with Probability Normalization in Hidden Markov Model #1072

Closed: TheCrafft closed this issue 7 months ago.

TheCrafft commented 7 months ago

Friday afternoon and stuck :(.

I've encountered an issue while working with a Hidden Markov Model (HMM) in which each of the 30 states can transition to every other state. The transition probabilities differ from state to state, but each state's outgoing transition probabilities always sum to 1.

Here's an example of the transition probabilities for one of the states:

probabilities = np.array([0.01655172, 0.02068966, 0.02068966, 0.02068966, 0.02068966,
                          0.02068966, 0.01655172, 0.02068966, 0.02758621, 0.02758621,
                          0.02758621, 0.02758621, 0.01655172, 0.02068966, 0.02758621,
                          0.04137931, 0.04137931, 0.04137931, 0.01655172, 0.02068966,
                          0.02758621, 0.04137931, 0.08275862, 0.08275862, 0.01655172,
                          0.02068966, 0.02758621, 0.04137931, 0.08275862, 0.08275862])

However, after calling model.bake() to finalize the model, the probabilities I read back for this state no longer match these values.

I suspect the issue is related to this section of bake():

# Check and normalize outgoing edges in the model
if merge in ['all', 'partial']:
    for state in list(self.graph.nodes()):
        out_edges = round(sum(numpy.e**x['probability']
                        for x in self.graph.adj[state].values()), 8)

        # Normalize edges to 1 if not already normalized
        if out_edges != 1. and state != self.end:
            if verbose:
                print("{} : {} summed to {}, normalized to 1.0"\
                        .format(self.name, state.name, out_edges))

            # Reweight the edges for probability sum to 1
            for edge in self.graph.adj[state].values():
                edge['probability'] = edge['probability'] - _log(out_edges)

I don't understand why the code exponentiates the stored values, sums them, checks whether the sum (rounded to 8 decimal places) equals 1.0, and renormalizes the probabilities if it does not. Since my initial probabilities already sum to 1.0 for each state's transitions, I'm puzzled as to why this normalization step is needed.

Could someone help clarify the purpose of this normalization step? It seems to alter already valid probabilities and might be contributing to the discrepancies observed after baking the model.

Thank you in advance for any insights or guidance on resolving this issue.

Here is the row of the baked transition matrix for the example state:

np.array([0.02803738, 0.05607477, 0.02803738, 0.01869159, 0.02803738,
       0.05607477, 0.05607477, 0.05607477, 0.02803738, 0.01869159,
       0.02803738, 0.05607477, 0.02803738, 0.02803738, 0.02803738,
       0.02803738, 0.01869159, 0.01869159, 0.01869159, 0.01869159,
       0.01869159, 0.01869159, 0.05607477, 0.01869159, 0.05607477,
       0.02803738, 0.01869159, 0.02803738, 0.05607477, 0.05607477,
       0.        , 0.        ])
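The two arrays above can be compared directly. A minimal NumPy check, using only the values quoted in this thread: both rows sum to 1 within rounding, and the baked row has 32 entries, i.e. 30 nonzero values plus two trailing zeros (assumed here to be the silent start and end states). Renormalizing or reordering the columns of a row that already sums to 1 would leave the same set of values, so if the sorted nonzero values differ, the baked row most likely belongs to a different state rather than having been altered by normalization.

```python
import numpy as np

# Row built for one state before baking (30 entries, quoted above)
before = np.array([0.01655172, 0.02068966, 0.02068966, 0.02068966, 0.02068966,
                   0.02068966, 0.01655172, 0.02068966, 0.02758621, 0.02758621,
                   0.02758621, 0.02758621, 0.01655172, 0.02068966, 0.02758621,
                   0.04137931, 0.04137931, 0.04137931, 0.01655172, 0.02068966,
                   0.02758621, 0.04137931, 0.08275862, 0.08275862, 0.01655172,
                   0.02068966, 0.02758621, 0.04137931, 0.08275862, 0.08275862])

# Row read back from the baked transition matrix (32 entries, quoted above)
after = np.array([0.02803738, 0.05607477, 0.02803738, 0.01869159, 0.02803738,
                  0.05607477, 0.05607477, 0.05607477, 0.02803738, 0.01869159,
                  0.02803738, 0.05607477, 0.02803738, 0.02803738, 0.02803738,
                  0.02803738, 0.01869159, 0.01869159, 0.01869159, 0.01869159,
                  0.01869159, 0.01869159, 0.05607477, 0.01869159, 0.05607477,
                  0.02803738, 0.01869159, 0.02803738, 0.05607477, 0.05607477,
                  0.0, 0.0])

# Drop the zero columns (assumed to be the silent start/end states)
after_nonzero = after[after > 0]

print("sum before:", round(float(before.sum()), 6))
print("sum after: ", round(float(after.sum()), 6))
print("nonzero entries after bake:", after_nonzero.size)

# A pure renormalization or column reordering would keep the same sorted values
same_values = np.allclose(np.sort(before), np.sort(after_nonzero))
print("same set of values:", same_values)
```

Both rows are valid distributions, but their values differ, which fits a reordering of states during baking better than a change caused by normalization.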

jmschrei commented 7 months ago

Basically, this was added to give people more flexibility when constructing HMMs programmatically. Rather than forcing your code to make each state's outgoing transitions sum to 1, bake() handles the normalization internally afterward. There is also a terminology issue: the values being exponentiated are not probabilities but log probabilities; they are just misleadingly stored under the name 'probability'.
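The normalization check quoted in the question can be reproduced outside pomegranate. The sketch below is a simplified stand-in, not pomegranate's code: it assumes each state's outgoing edges are kept as a plain dict of target name to log probability (a hypothetical layout), sums the exponentials, rounds to 8 decimal places as the quoted snippet does, and subtracts the log of the sum only when that sum is not 1. The helper name normalize_out_edges is illustrative.

```python
import math


def normalize_out_edges(log_probs):
    """Return log probabilities rescaled so their exponentials sum to 1.

    log_probs: dict mapping target state name -> log transition probability
    (a hypothetical stand-in for the per-edge data kept by the model).
    """
    # Sum in probability space, rounded to 8 decimals like the quoted check
    total = round(sum(math.exp(lp) for lp in log_probs.values()), 8)
    if total == 1.0:
        # Already normalized: leave the values untouched
        return dict(log_probs)
    # Dividing each probability by the total == subtracting log(total)
    return {dst: lp - math.log(total) for dst, lp in log_probs.items()}


# Transitions that already sum to 1 pass the check and come back unchanged
valid = {"s1": math.log(0.25), "s2": math.log(0.75)}
unchanged = normalize_out_edges(valid)

# Raw weights 2 and 6 (sum 8) are rescaled to 0.25 and 0.75
weights = {"s1": math.log(2.0), "s2": math.log(6.0)}
rescaled = normalize_out_edges(weights)
print({k: round(math.exp(v), 6) for k, v in rescaled.items()})
```

Because the comparison is made after rounding, a row that already sums to 1 (up to small floating-point error) is left as it is; only rows built from unnormalized weights are changed.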

Something else to consider is that bake() can reorder the states. Check the state names with [state.name for state in model.states] to make sure you're looking at the same state.

TheCrafft commented 7 months ago

Thanks! The states were indeed reordered during baking. Leaving this here for others:

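import re

# Map each baked index to the number embedded in the state's name, skipping
# the last two states (the silent start and end states in this model)
state_order = {
    idx: int(re.search(r"\d+", state.name).group())
    for idx, state in enumerate(model.states[:-2])
}

# Baked indices sorted by original state number
sorted_state_order = sorted(state_order, key=state_order.get)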
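Once the baked order is known, it can be used to permute a baked transition matrix back into the original state order. This is a sketch under stated assumptions: the matrix is a square NumPy array whose rows and columns follow model.states (for example, the matrix returned by model.dense_transition_matrix() in pomegranate 0.14), the silent start and end states are the last two entries, and the helper restore_original_order plus the toy state names are hypothetical. A small 3-state matrix stands in for the real 32 x 32 one.

```python
import re

import numpy as np


def restore_original_order(matrix, order):
    """Reorder rows and columns of a baked transition matrix.

    order lists baked indices sorted by original state number (like
    sorted_state_order). Rows and columns not listed in order, assumed to be
    the silent start/end states, are dropped.
    """
    order = np.asarray(order)
    return matrix[np.ix_(order, order)]


# Toy baked matrix whose rows/columns correspond to hypothetical states
# "s2", "s0", "s1", as if bake() had reordered them
baked = np.array([
    [0.1, 0.6, 0.3],
    [0.5, 0.2, 0.3],
    [0.4, 0.4, 0.2],
])
baked_names = ["s2", "s0", "s1"]

# Same mapping as above: baked index -> number embedded in the name
state_order = {idx: int(re.search(r"\d+", name).group())
               for idx, name in enumerate(baked_names)}
sorted_state_order = sorted(state_order, key=state_order.get)

# Rows/columns now follow s0, s1, s2
original = restore_original_order(baked, sorted_state_order)
print(original)
```

With a real model, the equivalent call would presumably be restore_original_order(model.dense_transition_matrix(), sorted_state_order), which also drops the start and end rows and columns because their indices are not in sorted_state_order.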