Describe the bug
I am trying to fit a GMM with a large number of components, but I get the following error when the number of components is too high.
I understand this is due to the fit_predict method assigning zero elements to some of the distributions of the GMM.
The subsequent attempt by the function _initialize to fit the data X[idx], with idx selecting zero rows, raises an error.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[78], line 1
----> 1 model.fit(torch.from_numpy(x_hat[ind][:,np.newaxis]))
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:245, in GeneralMixtureModel.fit(self, X, sample_weight, priors)
242 start_time = time.time()
244 last_logp = logp
--> 245 logp = self.summarize(X, sample_weight=sample_weight,
246 priors=priors)
248 if i > 0:
249 improvement = logp - last_logp
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:308, in GeneralMixtureModel.summarize(self, X, sample_weight, priors)
306 X = _check_parameter(_cast_as_tensor(X), "X", ndim=2)
307 if not self._initialized:
--> 308 self._initialize(X, sample_weight=sample_weight)
310 sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight,
311 dtype=torch.float32), device=self.device)
313 e = self._emission_matrix(X, priors=priors)
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:162, in GeneralMixtureModel._initialize(self, X, sample_weight)
159 for i in range(self.k):
160 idx = y_hat == i
--> 162 self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
163 self.priors[i] = idx.type(torch.float32).mean()
165 self._initialized = True
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:67, in Distribution.fit(self, X, sample_weight)
66 def fit(self, X, sample_weight=None):
---> 67 self.summarize(X, sample_weight=sample_weight)
68 self.from_summaries()
69 return self
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/normal.py:258, in Normal.summarize(self, X, sample_weight)
255 if self.frozen == True:
256 return
--> 258 X, sample_weight = super().summarize(X, sample_weight=sample_weight)
259 X = _cast_as_tensor(X, dtype=self.means.dtype)
261 if self.covariance_type == 'full':
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:73, in Distribution.summarize(self, X, sample_weight)
71 def summarize(self, X, sample_weight=None):
72 if not self._initialized:
---> 73 self._initialize(len(X[0]))
75 X = _cast_as_tensor(X)
76 _check_parameter(X, "X", ndim=2, shape=(-1, self.d),
77 check_parameter=self.check_data)
IndexError: index 0 is out of bounds for dimension 0 with size 0
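The failure mode in the last frame can be sketched with plain torch, without pomegranate. This is only an illustration under assumptions: the hard assignments y_hat below are made up (every point placed in cluster 0), standing in for whatever fit_predict returned in the real run.

```python
import torch

# Stand-in data: 100 identical one-dimensional points.
X = torch.zeros(100, 1)

# Hypothetical hard assignments: every point lands in cluster 0,
# so every other cluster receives no points at all.
y_hat = torch.zeros(100, dtype=torch.long)

idx = y_hat == 1   # boolean mask for a cluster with no members
subset = X[idx]    # empty selection, shape (0, 1)

try:
    subset[0]      # same access pattern as len(X[0]) in the traceback
    raised = False
except IndexError:
    raised = True

print(tuple(subset.shape), raised)
```

Any component whose mask is all False yields an empty tensor, and indexing its first row fails in the same way as the traceback above.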
To Reproduce
A minimal reproducible example, although a rather uninteresting one:
import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal

# Build a mixture with 40 components.
dl = []
for i in range(40):
    dl.append(Normal([0,1]).double())

model = GeneralMixtureModel(dl)
# 100 identical points (all zeros), far fewer distinct values than components.
model.fit(torch.randint(1,[100,1]))