jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] Transition matrix doesn't seem to change in HiddenMarkovModel.fit #1010

Closed dave-shore closed 1 year ago

dave-shore commented 1 year ago

Hello, I am trying to use pomegranate to fit and evaluate a Hidden Markov Model on the CoNLL-2003 data. When I fit the model, each state's emission distribution gets updated, but the transition matrix does not. A second issue, likely related to the first, is that the Viterbi predictions always follow the same pattern: a constant state until the last token of the input sequence.
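Here is one possible explanation for the unchanged transition matrix. It is offered as an illustration, not as a diagnosis reached in this thread. In the script below, every state starts from the same uniform emission distribution. As long as all states emit identically, the observations carry no information about which state produced them. The posterior over state paths then equals the prior Markov chain, and a Baum-Welch update returns the transition matrix exactly as it was.

The NumPy sketch that follows is not pomegranate's implementation. The helper name `reestimate_transitions` and the toy model sizes are hypothetical, and end-state transitions are ignored. It performs one scaled forward-backward re-estimation to show this effect.

```python
import numpy as np


def reestimate_transitions(A, pi, B, obs):
    """One Baum-Welch update of the transition matrix for a single discrete
    observation sequence, using scaled forward-backward passes.

    A: (K, K) transitions, pi: (K,) start probabilities,
    B: (K, M) emission probabilities, obs: 1-D array of symbol indices.
    End-state handling is ignored in this sketch.
    """
    T, K = len(obs), A.shape[0]
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    scale = np.zeros(T)

    # Forward pass, rescaled at every step to avoid underflow
    alpha[0] = pi * B[:, obs[0]]
    scale[0] = alpha[0].sum()
    alpha[0] /= scale[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * B[:, obs[t]]
        scale[t] = alpha[t].sum()
        alpha[t] /= scale[t]

    # Backward pass, reusing the forward scaling factors
    for t in range(T - 2, -1, -1):
        beta[t] = (A @ (B[:, obs[t + 1]] * beta[t + 1])) / scale[t + 1]

    # Accumulate expected transition counts xi_t(i, j) over time
    counts = np.zeros((K, K))
    for t in range(T - 1):
        xi = alpha[t][:, None] * A * (B[:, obs[t + 1]] * beta[t + 1])[None, :]
        counts += xi / xi.sum()

    # Normalise each row of expected counts into a probability distribution
    return counts / counts.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
K, M, T = 3, 4, 50
A = rng.random((K, K))
A /= A.sum(axis=1, keepdims=True)
pi = np.full(K, 1.0 / K)
obs = rng.integers(0, M, size=T)

# Every state shares the same emission distribution (like a uniform init)
B_same = np.tile(rng.random(M), (K, 1))
B_same /= B_same.sum(axis=1, keepdims=True)
A_same = reestimate_transitions(A, pi, B_same, obs)

# Each state has its own emission distribution
B_diff = rng.random((K, M))
B_diff /= B_diff.sum(axis=1, keepdims=True)
A_diff = reestimate_transitions(A, pi, B_diff, obs)

print(np.allclose(A_same, A))  # True: identical emissions leave A unchanged
print(np.abs(A_diff - A).max())  # state-specific emissions generally move A
```

The identical-emission case is an exact fixed point. The expected count for the transition i→j is the sum over t of p_t(i)·a_ij, and normalising each row recovers a_ij.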

Here is a code snippet that reproduces the results. Can you see anything I am doing wrong?

from datasets import load_dataset
from pomegranate import *
import numpy as np
import pandas as pd

corpus_train = load_dataset("conll2003", split = "train")
df_train = corpus_train.to_pandas()

# Define the entity tag map (available at https://huggingface.co/datasets/conll2003)
NERtags = {
    'O': 0, 
    'B-PER': 1, 
    'I-PER': 2, 
    'B-ORG': 3, 
    'I-ORG': 4, 
    'B-LOC': 5, 
    'I-LOC': 6, 
    'B-MISC': 7, 
    'I-MISC': 8
}

# And its inverse map
NERtags_inv = {n:c for c,n in NERtags.items()}
df_train['NER_tags'] = df_train['ner_tags'].apply(lambda L: [NERtags_inv[n] for n in L])

# Defining states
states = list(NERtags.keys())

# Defining starting probabilities
starting_probs = {s:int(not s.startswith("I")) for s in states}

# Defining transition probabilities
def fine_transition(s1, s2):

    if s2.startswith("I"):

        if s1 == "O":
            return 0

        else:
            s1_cat = s1.split("-")[-1]
            s2_cat = s2.split("-")[-1]
            return int(s1_cat == s2_cat)

    else:
        return 1

transition_matrix = {
    s1:{s2:fine_transition(s1, s2) for s2 in states}
for s1 in states}

# sent2feat (not shown) is an arbitrary user-defined function that returns one feature dict per token
X_train = df_train.apply(lambda row: sent2feat(row['tokens'], pos = row['pos_tags']), axis = 1)
y_train = df_train['NER_tags']

# Transform Series into DataFrame to get unique values of each feature
X_train_df = pd.DataFrame(X_train.explode().tolist())
X_train_df = X_train_df.applymap(str)
N_values = X_train_df.nunique()

### INITIALIZATION ###
hmm = HiddenMarkovModel()

# Univariate distribution of each feature (uniform to be as agnostic as possible)
univar_distributions = [
    DiscreteDistribution({v:1/N_values[col] for v in X_train_df[col].unique()}) 
for col in X_train_df.columns]

# State emission (prior) distributions, with states sorted by name to match pomegranate's internal state order
init_emissions = [
    State(IndependentComponentsDistribution(list(univar_distributions)), name = s) for s in sorted(states)
]
hmm.add_states(init_emissions)

# Starting and transition probabilities (defined before as dictionaries)
for s in init_emissions:
    hmm.add_transition(hmm.start, s, starting_probs[s.name])
    for s1 in init_emissions:
        hmm.add_transition(s, s1, transition_matrix[s.name][s1.name])

# Baking the model adjusts the parameters so that they form valid probability distributions
hmm.bake()

### FITTING ###

# Input data must be an iterable of 2-D arrays: one (n_tokens x n_features) array per sentence
input_data = X_train.apply(lambda seq: np.array([[str(f) for f in word.values()] for word in seq])).tolist()

# Fitting by Baum-Welch algorithm
hmm.fit(input_data, algorithm = "baum-welch", max_iterations = 1000)

### EVALUATION ###

# Print the (token, predicted state) pairs for the first 100 sentences
for i in range(100):
    print([(w, hmm.states[n].name) for w, n in zip(df_train['tokens'].iloc[i], hmm.predict(input_data[i], algorithm = "viterbi")[1:])])
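The constant Viterbi output may have the same cause as the unchanged transition matrix. If every state has the same emission distribution, the emission term adds the same constant to every state's score at each step. The best path is then set by the start and transition probabilities alone, so it is identical for every input sequence of a given length.

The log-space Viterbi sketch below illustrates this on a random toy model. It uses NumPy, and its `viterbi` helper and model sizes are hypothetical. It is not pomegranate's `predict`.

```python
import numpy as np


def viterbi(log_A, log_pi, log_B, obs):
    """Most likely state path for a discrete HMM, computed in log space.

    log_A: (K, K), log_pi: (K,), log_B: (K, M), obs: 1-D array of symbol indices.
    """
    T, K = len(obs), log_A.shape[0]
    delta = np.zeros((T, K))
    backptr = np.zeros((T, K), dtype=int)

    delta[0] = log_pi + log_B[:, obs[0]]
    for t in range(1, T):
        # scores[i, j]: best log-probability of reaching state j at time t via state i
        scores = delta[t - 1][:, None] + log_A
        backptr[t] = scores.argmax(axis=0)
        delta[t] = scores.max(axis=0) + log_B[:, obs[t]]

    # Trace the best path backwards from the most likely final state
    path = [int(delta[-1].argmax())]
    for t in range(T - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1]


rng = np.random.default_rng(1)
K, M, T = 3, 5, 20
A = rng.random((K, K))
A /= A.sum(axis=1, keepdims=True)
pi = np.full(K, 1.0 / K)

# All states share one emission distribution
B_same = np.tile(rng.random(M), (K, 1))
B_same /= B_same.sum(axis=1, keepdims=True)

obs1 = rng.integers(0, M, size=T)
obs2 = rng.integers(0, M, size=T)
path1 = viterbi(np.log(A), np.log(pi), np.log(B_same), obs1)
path2 = viterbi(np.log(A), np.log(pi), np.log(B_same), obs2)
print(path1 == path2)  # True: with identical emissions the path ignores the data
```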

Thank you very much for your kind support in advance.
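A related detail in the emission set-up of the script: `list(univar_distributions)` copies the list but not the distribution objects inside it. Every state's `IndependentComponentsDistribution` is therefore built from the very same component objects. This thread does not establish whether pomegranate's training treats such shared components as tied.

The plain-Python sketch below uses a hypothetical `Component` stand-in class rather than pomegranate's distributions. It only shows the copy semantics, and how `copy.deepcopy` gives each state its own objects.

```python
import copy


class Component:
    """Hypothetical stand-in for a per-feature distribution with mutable state."""

    def __init__(self, probs):
        self.probs = dict(probs)


components = [Component({"a": 0.5, "b": 0.5}), Component({"x": 0.5, "y": 0.5})]
states = ["O", "B-PER", "I-PER"]

# list() makes a new list holding references to the SAME Component objects
shallow = {s: list(components) for s in states}
# copy.deepcopy() makes independent Component objects for every state
deep = {s: copy.deepcopy(components) for s in states}

# Update the first component of state "O" in both set-ups
shallow["O"][0].probs["a"] = 0.9
deep["O"][0].probs["a"] = 0.9

print(shallow["B-PER"][0].probs["a"])  # 0.9: the change is visible from every state
print(deep["B-PER"][0].probs["a"])     # 0.5: the other states are unaffected
```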

jmschrei commented 1 year ago

Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.