jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
MIT License
3.29k stars, 590 forks

log_probability not implemented for zero_inflated poisson HMM #1069

Open hfealr1111 opened 7 months ago

hfealr1111 commented 7 months ago

Hi, I am trying to infer and analyze hidden states in neuron spiking data with a zero-inflated Poisson HMM (ZIP-HMM). I create the model uninitialized and then fit it to the data:

model = DenseHMM([ZeroInflated(Poisson()), ZeroInflated(Poisson()), ZeroInflated(Poisson())], max_iter=1000, verbose=True)

However, fitting fails with:

/usr/local/lib/python3.10/dist-packages/pomegranate/distributions/ in log_probability(self, X)
     63         def log_probability(self, X):
---> 64                 raise NotImplementedError
     66         def fit(self, X, sample_weight=None):


I understand that zero_inflated is a wrapper, so it presumably has no dedicated log_probability function of its own. Could you confirm whether ZeroInflated(Poisson()) can be used in an HMM this way? If so, could you kindly suggest a solution? Thanks in advance!
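For reference, the quantity the HMM asks each emission distribution for is a per-observation log-probability. Below is a minimal sketch of what that would be for a zero-inflated Poisson, in plain Python and independent of pomegranate's internals. The function name `zip_log_probability` and its parameters `lam` and `pi` are hypothetical, chosen only for illustration, and are not part of the library's API.

```python
import math


def zip_log_probability(k, lam, pi):
    """Log-probability of a count k under a zero-inflated Poisson.

    lam: Poisson rate (> 0).
    pi:  probability of a structural (inflated) zero, 0 <= pi < 1.
    Hypothetical helper for illustration; not pomegranate code.
    """
    # Ordinary Poisson log-pmf: k*log(lam) - lam - log(k!)
    poisson_logp = k * math.log(lam) - lam - math.lgamma(k + 1)
    if k == 0:
        # A zero can come from the inflation component or from the Poisson.
        return math.log(pi + (1.0 - pi) * math.exp(poisson_logp))
    # A positive count can only come from the Poisson component.
    return math.log(1.0 - pi) + poisson_logp
```

With `pi = 0` this reduces to the ordinary Poisson log-pmf, which is one way to sanity-check any eventual fix.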

jmschrei commented 7 months ago

Yes, this is an error on my side. I will look into a solution.

hfealr1111 commented 7 months ago

Thanks, I really appreciate it!

KunFang93 commented 1 week ago


I ran into a similar issue:

Traceback (most recent call last):
  File "/Applications/PyCharm", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/", line 604, in fit
    logp += self.summarize(X_, sample_weight=w_, priors=p_).sum()
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/", line 543, in summarize
    X, emissions, sample_weight = super().summarize(X, 
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/", line 681, in summarize
    emissions = _check_inputs(self, X, emissions, priors)
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/", line 28, in _check_inputs
    emissions = model._emission_matrix(X, priors=priors)
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/", line 287, in _emission_matrix
    logp = node.log_probability(X)
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/distributions/", line 64, in log_probability
    raise NotImplementedError

My model is:

hmm = pg.hmm.DenseHMM(

and the calling line

x =

where obs_final is a NumPy array with shape (1, 1119101, 11).

Am I setting this up correctly? Thanks in advance!

Best, Kun