jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] GMM (and normal distribution) fitting doesn't respect frozen parameters #1054

Open · NicholasClark opened this issue 1 year ago

NicholasClark commented 1 year ago

I am trying to fit a mixture of two normal distributions where I freeze the means at 4 and 1.5 and fit only the variances. When I fit a GeneralMixtureModel to the data, it changes the means anyway (the fitted means are 3.98 and 1.47). I see the same issue if I fit just a single Normal distribution with a frozen mean. I may be doing something wrong, but I've tried it a number of different ways at this point.

Any help would be highly appreciated!

Here is code to reproduce the issue:

import seaborn
import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import *
import numpy as np
import matplotlib.pyplot as plt

### Generate data for mixture model
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])
XX = np.array(X).reshape(-1,1)
XX = torch.tensor(XX).float()
### Fit mixture model and freeze the mean of each distribution
m1 = torch.tensor([4]) ### mean = 4
m2 = torch.tensor([1.5]) ### mean = 1.5
m1.frozen = True
m2.frozen = True
d1 = Normal(means=m1)
d2 = Normal(means=m2)
model = GeneralMixtureModel([d1, d2], verbose=False).fit(XX)
### plot results
x = np.arange(np.min(X), np.max(X), 0.1)
y1 = model.distributions[0].probability(x.reshape(-1, 1))
y2 = model.distributions[1].probability(x.reshape(-1, 1))
y3 = model.probability(x.reshape(-1, 1))
plt.figure(figsize=(6, 3))
plt.hist(X, density=True, bins=30)
plt.plot(x, y1, color = "green", label="Normal1")
plt.plot(x, y2, color = "red", label="Normal2")
plt.plot(x, y3, color = "purple", label="Mixture")
plt.legend(loc=(1.05, 0.4))
plt.tight_layout()
plt.show()
print("mean of Normal1: " + str(round(model.distributions[0].means.item(), 2)))
print("mean of Normal2: " + str(round(model.distributions[1].means.item(), 2)))

[Figure: histogram of X with the fitted Normal1, Normal2, and mixture densities overlaid; the fitted means have drifted from 4 and 1.5]

ShaolinXU commented 8 months ago

I think the finest-grained control is at the level of the distribution you define.

I managed to modify normal.py in the source code to freeze the means, as follows:

remove the _update_parameter call for the means from from_summaries(self)

Please point it out if I did it wrong.
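
For readers who would rather not edit the installed package, here is a minimal sketch of the same idea as a subclass instead of a source edit: let the usual update run, then restore the means. FixedMeanNormal is a hypothetical name, and this assumes from_summaries() is the hook where parameters are re-estimated, as the edit above implies.

import torch
from pomegranate.distributions import Normal

class FixedMeanNormal(Normal):
    ### Hypothetical workaround: keep the means fixed across EM updates
    def from_summaries(self):
        fixed = self.means.clone()    ### remember the desired means
        super().from_summaries()      ### re-estimate all parameters as usual
        with torch.no_grad():
            self.means.copy_(fixed)   ### put the frozen means back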

jmschrei commented 6 months ago

Hi @NicholasClark.

Sorry for the late reply. You are correct that you can freeze individual parameters, but you have to do it in a specific way to get it to stick.

First, you added the frozen attribute to the underlying tensor. When that tensor is taken into the Normal object it gets wrapped in a torch.nn.Parameter, so the frozen attribute is still attached to your original tensor but not to d.means (which is the new Parameter object). You can solve this by adding frozen to the parameters you want frozen after creating the object.
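
A minimal illustration of that wrapping behavior (assuming pomegranate v1.x, where distribution parameters are stored as torch.nn.Parameter objects):

import torch
from pomegranate.distributions import Normal

m = torch.tensor([4.0])
m.frozen = True                      ### attribute set on the raw tensor
d = Normal(means=m, covs=[1.0], covariance_type='diag')
print(type(d.means))                 ### a torch.nn.Parameter, a different object
print(hasattr(d.means, 'frozen'))    ### likely False: the attribute stayed on m
d.means.frozen = True                ### set it on the Parameter itself instead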

Second, pomegranate does not allow you to incompletely specify distributions as starting points. This should probably raise a warning when it happens. So what happened is that the distribution did not register as being initialized and so was overwritten in the first step of fitting the GMM. You can get around this by also passing a value for covs so that the distribution is completely specified.

This code works for me and keeps the means frozen. I took out the plotting stuff just because it wasn't relevant for me.


import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal
import numpy as np

### Generate data for mixture model
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])
XX = np.array(X).reshape(-1,1)
XX = torch.tensor(XX).float()
### Fit mixture model and freeze the mean of each distribution
m1 = torch.tensor([4]) ### mean = 4
m2 = torch.tensor([1.5]) ### mean = 1.5

d1 = Normal(means=m1, covs=[1], covariance_type='diag')
d2 = Normal(means=m2, covs=[1], covariance_type='diag')

d1.means.frozen = True
d2.means.frozen = True

model = GeneralMixtureModel([d1, d2], verbose=True).fit(XX)

print(model.distributions[0].means.frozen)
print("mean of Normal1: " + str(round(model.distributions[0].means.item(), 2)))
print("mean of Normal2: " + str(round(model.distributions[1].means.item(), 2)))

When I run this it gives me:

[1] Improvement: 35.6845703125, Time: 0.0005453s
[2] Improvement: 1.496337890625, Time: 0.0005322s
[3] Improvement: 0.07568359375, Time: 0.0005288s
True
mean of Normal1: 4
mean of Normal2: 1.5
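
As a quick follow-up check (assuming the v1 attribute names used above; exact values depend on the data), you can confirm that the covariances did update while the means stayed put:

### covariances were re-estimated even though the means stayed frozen
print(model.distributions[0].covs)    ### fitted variance of the first component
print(model.distributions[1].covs)    ### fitted variance of the second component
print(model.priors)                   ### learned mixture weights (assumed attribute name)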