jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] GMM.sample(): probabilities do not sum to 1 #968

Closed profPlum closed 1 year ago

profPlum commented 2 years ago

Describe the bug

When fitting a GMM to a dataset using Bernoulli and Normal distributions and then calling sample(), it throws an error: probabilities do not sum to 1

To Reproduce

import numpy as np
import pandas as pd
from pomegranate import BernoulliDistribution, GeneralMixtureModel, NormalDistribution

# Placeholder: folder that holds the CSV (`dir` was not defined in the original report).
data_dir = './'
# dataset from here: https://www.kaggle.com/datasets/alexteboul/diabetes-health-indicators-dataset
data = pd.read_csv(data_dir + 'diabetes_binary_health_indicators_BRFSS2015.csv')

# verified to work! 4/3/22
# prepare df for PGMs which require discrete data only!
def make_df_categorical(data, max_cols=9, required_cols=(), max_vals_for_categorical=15):
  # Count the distinct values in every column.
  unique_values = {col: len(data[col].unique()) for col in data}
  print(unique_values)
  # Columns with few distinct values are candidates for categorical treatment.
  categorical_candidates = [col for col in data if unique_values[col] < max_vals_for_categorical]
  categorical_candidates = list(set(categorical_candidates) - set(required_cols))
  # Randomly pick columns until max_cols is reached, always keeping required_cols.
  categorical_candidates = list(np.random.choice(categorical_candidates, max_cols-len(required_cols), replace=False)) + list(required_cols)
  # Re-encode every selected column as integer codes 0..k-1.
  data_down_sample = data[categorical_candidates].apply(lambda x: x.factorize()[0])
  return data_down_sample

def fit_GMM(data, model=None, n_components=5, default_dist=NormalDistribution):
  # A column is treated as Bernoulli when it only contains 0/1 (or True/False).
  is_binom = lambda x: np.all(np.isin(x, [1, 0, True, False]))
  distributions = [(BernoulliDistribution if is_binom(data[x]) else default_dist) for x in data.columns]
  GMM = GeneralMixtureModel.from_samples(distributions, n_components=n_components, X=data)
  return GMM

# The original snippet never called make_df_categorical, so data_down_sample was undefined here.
data_down_sample = make_df_categorical(data)
GMM = fit_GMM(data_down_sample)#, default_dist=DiscreteDistribution)
GMM.sample()

Most of the time this code fails with: probabilities do not sum to 1. It appears to happen only when at least one column falls back to default_dist (NormalDistribution, i.e. Gaussian); it does not happen when, for example, all of the down-sampled columns are binary. NOTE: the dataset contains only binary and integer data.
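For anyone trying to narrow this down: the message reads like the ValueError that NumPy's np.random.choice raises when its p argument is not a valid probability vector. Whether sample() actually reaches that call is an assumption here, not something confirmed from the pomegranate source. The sketch below is a hypothetical diagnostic (diagnose_weights and numpy_rejects are made-up helper names, not part of pomegranate) that reports what is wrong with a weight vector and shows NumPy rejecting one that sums to 0.9.

```python
import numpy as np

def diagnose_weights(probs, atol=1e-8):
    """Return a list of problems found in a probability vector (hypothetical helper)."""
    probs = np.asarray(probs, dtype=float)
    problems = []
    if not np.all(np.isfinite(probs)):
        # NaN or inf weights can appear if a component's fit degenerates.
        problems.append("contains NaN or inf")
    elif np.any(probs < 0):
        problems.append("contains negative entries")
    elif abs(probs.sum() - 1.0) > atol:
        problems.append("sums to %r instead of 1" % probs.sum())
    return problems

def numpy_rejects(probs):
    """Return NumPy's error message if np.random.choice rejects probs, else None."""
    try:
        np.random.choice(len(probs), p=probs)
    except ValueError as err:
        return str(err)
    return None

if __name__ == "__main__":
    print(diagnose_weights([0.5, 0.5]))       # valid vector, no problems reported
    print(diagnose_weights([0.3, 0.3, 0.3]))  # sum is off
    print(numpy_rejects([0.3, 0.3, 0.3]))     # NumPy's own message
```

Passing the fitted model's mixture weights (converted to plain probabilities first if the library stores them as logs, which is also an assumption) through diagnose_weights should show whether the fit produced NaN or unnormalized weights before sample() is ever called.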

jmschrei commented 1 year ago

Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.