jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

log_probability not implemented for zero_inflated poisson HMM #1069

Open hfealr1111 opened 7 months ago

hfealr1111 commented 7 months ago

Hi, I am trying to infer and analyze hidden states in neuronal spike data by fitting an uninitialized zero-inflated Poisson HMM (ZIP-HMM) to my data:

model = DenseHMM([ZeroInflated(Poisson()), ZeroInflated(Poisson()), ZeroInflated(Poisson())], max_iter=1000, verbose=True)

However, I get the following error:

/usr/local/lib/python3.10/dist-packages/pomegranate/distributions/_distribution.py in log_probability(self, X)
     62 
     63         def log_probability(self, X):
---> 64                 raise NotImplementedError
     65 
     66         def fit(self, X, sample_weight=None):

NotImplementedError: 

I understand that ZeroInflated is a wrapper, so it may not define its own log_probability method. Could you confirm whether ZeroInflated(Poisson()) can be used in an HMM this way, and if so, suggest a fix? Thanks in advance!
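For context, a zero-inflated Poisson mixes a point mass at zero (weight pi) with a Poisson(lam) count, so the log-probability each HMM state needs from it is log(pi + (1 - pi) * exp(-lam)) for a zero and log(1 - pi) - lam + k*log(lam) - log(k!) for a count k > 0. The sketch below computes this in plain Python for one multi-feature observation, assuming the features are independent given the state so their per-feature terms add. The function `zip_log_probability` and its parameters are hypothetical illustrations, not part of pomegranate's API.

```python
import math
from typing import Sequence


def zip_log_probability(x: Sequence[int], pi: Sequence[float], lam: Sequence[float]) -> float:
    """Hypothetical helper: log-probability of one observation under
    independent per-feature zero-inflated Poisson distributions.

    x   -- non-negative integer counts, one per feature
    pi  -- zero-inflation weights, one per feature, with 0 <= pi < 1
    lam -- Poisson rates, one per feature, with lam > 0
    """
    if not (len(x) == len(pi) == len(lam)):
        raise ValueError("x, pi and lam must have the same length")
    total = 0.0
    for k, p, rate in zip(x, pi, lam):
        if k < 0 or not 0.0 <= p < 1.0 or rate <= 0.0:
            raise ValueError("need k >= 0, 0 <= pi < 1 and lam > 0")
        # Poisson branch weighted by (1 - pi): log(1 - pi) + log Poisson(k | lam)
        log_poisson = math.log1p(-p) - rate + k * math.log(rate) - math.lgamma(k + 1)
        if k == 0 and p > 0.0:
            # A zero can come from the inflation component or from the Poisson,
            # so combine log(pi) and the Poisson branch with a stable log-sum-exp.
            a, b = math.log(p), log_poisson
            m = max(a, b)
            total += m + math.log(math.exp(a - m) + math.exp(b - m))
        else:
            # Nonzero counts (or pi == 0) come only from the Poisson branch.
            total += log_poisson
    return total


# With pi = 0 this reduces to a plain Poisson: log P(0 | lam=2) = -2
print(zip_log_probability([0], [0.0], [2.0]))  # -2.0
```

A state-level emission in an HMM would evaluate a function like this for every time step and state; the log-sum-exp for zeros avoids underflow when exp(-lam) is tiny.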

jmschrei commented 7 months ago

Yes, this is an error on my side. I will look into a solution.

hfealr1111 commented 7 months ago

Thanks, I really appreciate it!

KunFang93 commented 1 week ago

Hi,

I ran into a similar issue:

Traceback (most recent call last):
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 604, in fit
    logp += self.summarize(X_, sample_weight=w_, priors=p_).sum()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/dense_hmm.py", line 543, in summarize
    X, emissions, sample_weight = super().summarize(X, 
                                  ^^^^^^^^^^^^^^^^^^^^
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 681, in summarize
    emissions = _check_inputs(self, X, emissions, priors)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 28, in _check_inputs
    emissions = model._emission_matrix(X, priors=priors)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/hmm/_base.py", line 287, in _emission_matrix
    logp = node.log_probability(X)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kfang/miniconda3/envs/pomegranate/lib/python3.12/site-packages/pomegranate/distributions/_distribution.py", line 64, in log_probability
    raise NotImplementedError
NotImplementedError

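My model is:

import pomegranate as pg
import pomegranate.hmm  # ensures the pg.hmm submodule is loaded
from pomegranate.distributions import Poisson, ZeroInflated

hmm = pg.hmm.DenseHMM(
    [ZeroInflated(Poisson()) for _ in range(11)],  # a separate distribution object per state
    random_state=34,
    max_iter=200
)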

and I call fit with:

x = hmm.fit(obs_final)

where obs_final is a NumPy array of shape (1, 1119101, 11), i.e., a single sequence of 1,119,101 time steps with 11 count features each.
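Separately from the NotImplementedError, the state list is worth double-checking. In Python, writing it as `[ZeroInflated(Poisson())]*11` calls the constructor once and repeats a reference to that single object 11 times, so all 11 states would share one emission distribution; assuming fitting updates each state's distribution in place, the states would then end up with identical emission parameters. A list comprehension such as `[ZeroInflated(Poisson()) for _ in range(11)]` creates a separate object per state. This plain-Python sketch shows the difference, using a hypothetical `Emission` class as a stand-in for the real distribution:

```python
class Emission:
    """Hypothetical stand-in for a distribution object such as ZeroInflated(Poisson())."""

    def __init__(self):
        self.rate = 1.0  # a parameter that fitting would update in place


# List multiplication repeats the reference: every "state" is the same object.
shared = [Emission()] * 3
# A comprehension calls the constructor once per state: independent objects.
independent = [Emission() for _ in range(3)]

# Simulate "fitting" only the first state in each list.
shared[0].rate = 5.0
independent[0].rate = 5.0

print([e.rate for e in shared])       # [5.0, 5.0, 5.0]
print([e.rate for e in independent])  # [5.0, 1.0, 1.0]
print(len({id(e) for e in shared}), len({id(e) for e in independent}))  # 1 3
```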

Am I doing this correctly? Thanks in advance!

Best, Kun