jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

NotImplementedError: zero-inflated poisson HMM #1107

Closed: KunFang93 closed this issue 3 months ago

KunFang93 commented 3 months ago

Hi Jacob,

I wondered whether it is possible to implement a zero-inflated Poisson HMM in the current version of pomegranate. I saw an old, unresolved issue with the same NotImplementedError. I tried:

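import numpy as np
from pomegranate.distributions import Poisson, ZeroInflated
from pomegranate.hmm import DenseHMM

def zero_inflated_poisson(size, lam=3, zero_prob=0.5):
    # Draw Poisson counts, then force each one to zero with probability zero_prob
    poisson_part = np.random.poisson(lam, size)
    keep = np.random.binomial(1, 1 - zero_prob, size)
    return poisson_part * keep

num_observations = 1000
num_features = 11
zero_prob = 0.5

# Shape (num_observations, num_features), then add a leading sequence axis: (1, 1000, 11)
data = np.array([zero_inflated_poisson(num_observations, zero_prob=zero_prob) for _ in range(num_features)]).T
data = data.reshape(1, num_observations, num_features)

# Build 20 independent states; [ZeroInflated(Poisson())]*20 would reuse one object 20 times
model = DenseHMM([ZeroInflated(Poisson()) for _ in range(20)], verbose=True, random_state=34)
model.fit(data)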

but got the same error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 604, in fit
    logp += self.summarize(X_, sample_weight=w_, priors=p_).sum()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/dense_hmm.py", line 543, in summarize
    X, emissions, sample_weight = super().summarize(X,
                                  ^^^^^^^^^^^^^^^^^^^^
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 681, in summarize
    emissions = _check_inputs(self, X, emissions, priors)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 28, in _check_inputs
    emissions = model._emission_matrix(X, priors=priors)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 287, in _emission_matrix
    logp = node.log_probability(X)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kfang/.conda/envs/mamba/envs/pomegranate/lib/python3.12/site-packages/pomegranate/distributions/_distribution.py", line 64, in log_probability
    raise NotImplementedError
NotImplementedError

Thanks for your time and help!

Best, Kun

jmschrei commented 3 months ago

Sorry for the delay in responding to these issues. The short answer is that this is conceptually simple to add, but implementation-wise I struggled a little because the zero-inflated wrapper needs access to the per-feature log probabilities, while my current implementation only returns the total log probability summed across all features. I think I need to implement dedicated ZIP (zero-inflated Poisson) and ZINB (zero-inflated negative binomial) distributions, since those are the most widely used anyway. I don't know when I'll get a chance to do that, though, so I'd recommend trying to implement them yourself based on the Poisson distribution and the ZeroInflated wrapper I provided.
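For reference, here is the standard zero-inflated Poisson likelihood for one hidden state with per-feature rates $\lambda_j$ and structural-zero probabilities $\pi_j$ over $d$ features. This is the textbook form, assuming features are conditionally independent given the state; pomegranate's ZeroInflated wrapper may parameterize it differently.

```math
p_{\mathrm{ZIP}}(x \mid \lambda, \pi) =
\begin{cases}
\pi + (1 - \pi)\, e^{-\lambda}, & x = 0,\\
(1 - \pi)\, \dfrac{\lambda^{x} e^{-\lambda}}{x!}, & x > 0,
\end{cases}
\qquad
\log p(\mathbf{x}) = \sum_{j=1}^{d} \log p_{\mathrm{ZIP}}(x_j \mid \lambda_j, \pi_j)
```

Because the $x = 0$ branch mixes the structural zero with each feature's own Poisson term $e^{-\lambda_j}$, it has to be evaluated feature by feature; once the Poisson log probabilities have been summed into a single number, the individual $e^{-\lambda_j}$ terms cannot be separated again.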

KunFang93 commented 3 months ago

Not a problem at all, and thank you for the reply! Could you shed some light on how to implement this using the Poisson distribution and the ZeroInflated wrapper? Is it something like this:

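# Continues from the first snippet: data has shape (1, 1000, 11)
states = []
for i in range(20):  # assume 20 hidden states
    cur_d = ZeroInflated(Poisson())
    cur_d.fit(data[0])  # distributions expect a 2D (n_observations, n_features) array
    states.append(cur_d)
model = DenseHMM(states)
model.fit(data)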

Any suggestions would be greatly appreciated! Thank you so much again for your time and help~

Best, Kun

jmschrei commented 3 months ago

You'd need to create a new distribution object that applies the formula from the ZeroInflated wrapper to the per-feature log probabilities BEFORE they are summed across features. The current objects only work correctly when the data has a single dimension (one feature).
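A minimal NumPy sketch of such a distribution, assuming features are conditionally independent given the hidden state. `ZIPEmission`, its constructor arguments, and its `fit` method are hypothetical names for illustration and do not follow pomegranate's `Distribution` interface, so using it inside `DenseHMM` would still mean reimplementing it against that interface. `log_probability` applies the zero-inflation formula to each feature before summing, and `fit` runs weighted EM, where `sample_weight` stands in for the per-state posterior weights an HMM's M-step would supply.

```python
import numpy as np
from math import lgamma


class ZIPEmission:
    """Hypothetical independent-feature zero-inflated Poisson emission (not a pomegranate class)."""

    def __init__(self, lam, pi):
        self.lam = np.asarray(lam, dtype=float)  # (d,) Poisson rates
        self.pi = np.asarray(pi, dtype=float)    # (d,) structural-zero probabilities

    def log_probability(self, X):
        X = np.asarray(X, dtype=float)                    # (n, d) counts
        log_fact = np.vectorize(lgamma)(X + 1.0)          # log(x!)
        log_pois = X * np.log(self.lam) - self.lam - log_fact
        # Zero-inflation formula applied per feature, BEFORE summing
        log_zero = np.log(self.pi + (1.0 - self.pi) * np.exp(-self.lam))  # (d,)
        log_nonzero = np.log1p(-self.pi) + log_pois                         # (n, d)
        per_feature = np.where(X == 0, log_zero, log_nonzero)
        return per_feature.sum(axis=1)                    # independence across features

    def fit(self, X, sample_weight=None, max_iter=200, tol=1e-8):
        X = np.asarray(X, dtype=float)
        n = X.shape[0]
        w = np.ones(n) if sample_weight is None else np.asarray(sample_weight, dtype=float)
        w = w[:, None]  # (n, 1), broadcast across features
        is_zero = X == 0
        for _ in range(max_iter):
            # E-step: posterior probability that each zero came from the structural-zero component
            p0 = self.pi / (self.pi + (1.0 - self.pi) * np.exp(-self.lam))
            z = np.where(is_zero, p0, 0.0)
            # M-step: weighted per-feature updates of pi and lambda
            new_pi = (w * z).sum(axis=0) / w.sum()
            new_lam = (w * (1.0 - z) * X).sum(axis=0) / (w * (1.0 - z)).sum(axis=0)
            converged = np.allclose(new_pi, self.pi, atol=tol) and np.allclose(new_lam, self.lam, atol=tol)
            self.pi, self.lam = new_pi, new_lam
            if converged:
                break
        return self
```

For example, `ZIPEmission(lam=np.ones(11), pi=np.full(11, 0.5)).fit(data[0])` would fit a single state to the 11-feature data from the first snippet. Keeping `log_probability` and `fit` in one object mirrors what an HMM needs from each state: per-observation log probabilities for the forward-backward pass and a weighted parameter update for the M-step.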

KunFang93 commented 3 months ago

Got it, I'll dig into it. Thanks for the guidance!