jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] Fit of GMM returns initialization error because some distributions are not returned by fit_predict #1043

Open Mriv31 opened 1 year ago

Mriv31 commented 1 year ago

Describe the bug

I try to fit a GMM with many components, but I get the following error when the number of components is too high. I understand this is because the fit_predict call used during initialization assigns zero samples to some of the distributions of the GMM. The subsequent attempt by GeneralMixtureModel._initialize to fit the data X[idx], with idx being a zero-length index array, raises an error.

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[78], line 1
----> 1 model.fit(torch.from_numpy(x_hat[ind][:,np.newaxis]))

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:245, in GeneralMixtureModel.fit(self, X, sample_weight, priors)
    242 start_time = time.time()
    244 last_logp = logp
--> 245 logp = self.summarize(X, sample_weight=sample_weight, 
    246     priors=priors)
    248 if i > 0:
    249     improvement = logp - last_logp

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:308, in GeneralMixtureModel.summarize(self, X, sample_weight, priors)
    306 X = _check_parameter(_cast_as_tensor(X), "X", ndim=2)
    307 if not self._initialized:
--> 308     self._initialize(X, sample_weight=sample_weight)
    310 sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight, 
    311     dtype=torch.float32), device=self.device)
    313 e = self._emission_matrix(X, priors=priors)

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:162, in GeneralMixtureModel._initialize(self, X, sample_weight)
    159 for i in range(self.k):
    160     idx = y_hat == i
--> 162     self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
    163     self.priors[i] = idx.type(torch.float32).mean()
    165 self._initialized = True

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:67, in Distribution.fit(self, X, sample_weight)
     66 def fit(self, X, sample_weight=None):
---> 67     self.summarize(X, sample_weight=sample_weight)
     68     self.from_summaries()
     69     return self

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/normal.py:258, in Normal.summarize(self, X, sample_weight)
    255 if self.frozen == True:
    256     return
--> 258 X, sample_weight = super().summarize(X, sample_weight=sample_weight)
    259 X = _cast_as_tensor(X, dtype=self.means.dtype)
    261 if self.covariance_type == 'full':

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:73, in Distribution.summarize(self, X, sample_weight)
     71 def summarize(self, X, sample_weight=None):
     72     if not self._initialized:
---> 73         self._initialize(len(X[0]))
     75     X = _cast_as_tensor(X)
     76     _check_parameter(X, "X", ndim=2, shape=(-1, self.d), 
     77         check_parameter=self.check_data)

IndexError: index 0 is out of bounds for dimension 0 with size 0
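
The final IndexError can be reproduced in isolation, without a mixture at all: fitting a not-yet-initialized distribution on a zero-length input hits the same len(X[0]) call in Distribution.summarize. A minimal sketch of that case (X_empty is just an illustrative stand-in for the empty X[idx] slice):

import torch
from pomegranate.distributions import Normal

# Zero samples, one feature: what X[idx] looks like when a component
# receives no points from the initial clustering.
X_empty = torch.empty(0, 1, dtype=torch.float64)

# Raises IndexError: index 0 is out of bounds for dimension 0 with size 0
Normal([0, 1]).double().fit(X_empty)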

To Reproduce

A minimal reproducible example, although rather uninteresting:

import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal

# Build 40 Normal components; no covariances are given, so each one
# still has to be initialized from data.
dl = []
for _ in range(40):
    dl.append(Normal([0, 1]).double())

model = GeneralMixtureModel(dl)

# torch.randint with high=1 yields a tensor of zeros, so all 100 samples
# are identical and the initial clustering leaves most components empty.
model.fit(torch.randint(1, [100, 1]))
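
Based on the initialization loop visible in the traceback (gmm.py, GeneralMixtureModel._initialize), a possible direction for a fix could be to guard against components that receive no samples. This is only a rough sketch of that loop with a hypothetical fallback added; the fallback behaviour (fitting the empty component on the full data and giving it a uniform prior) is my suggestion, not anything the library currently does:

# Hypothetical guard inside GeneralMixtureModel._initialize, sketched
# from the lines shown in the traceback above.
for i in range(self.k):
    idx = y_hat == i

    if idx.sum() == 0:
        # No samples were assigned to this component: fit it on the
        # full data set instead of indexing an empty slice.
        # (The priors would then need re-normalising afterwards.)
        self.distributions[i].fit(X, sample_weight=sample_weight)
        self.priors[i] = 1.0 / self.k
        continue

    self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
    self.priors[i] = idx.type(torch.float32).mean()

self._initialized = True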