jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] `pomegranate.distributions.LogNormal.summarize()` fails to invoke `super()` #1122

Open levon003 opened 2 days ago

levon003 commented 2 days ago

Bug Description

LogNormal.summarize() raises an AttributeError because self.means is still None when its dtype is read.

This is because Normal.summarize() correctly calls _distribution.Distribution.summarize() via super() before doing any processing of X:

X, sample_weight = super().summarize(X, sample_weight=sample_weight)

LogNormal.summarize(), by contrast, accesses self.means (which is still None on a freshly constructed distribution) before super().summarize() is called.

To Reproduce

import numpy as np
import pomegranate.distributions
x_train = np.array([0, 1, 2, 2, 3, 4, 20, 21, 22, 22, 23, 24]).reshape([-1, 1])
d = pomegranate.distributions.LogNormal()
d.fit(x_train)

produces:

File .../site-packages/pomegranate/distributions/lognormal.py:174, in LogNormal.summarize(self, X, sample_weight)
    172 if self.frozen is True:
    173     return
--> 174 X = _cast_as_tensor(X, dtype=self.means.dtype)
    175 super().summarize(X.log(), sample_weight=sample_weight)

AttributeError: 'NoneType' object has no attribute 'dtype'

I would cut a PR for this, but I'm not actually sure what the most Pythonic resolution for this kind of inheritance issue is.
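For what it's worth, one possible fix (a sketch only, untested; it assumes Distribution.summarize() initializes self.means on its first call, and it ends up invoking Distribution.summarize() twice, once directly and once again via Normal.summarize(), just like the workarounds below) would be to reorder LogNormal.summarize() to match Normal's pattern:

def summarize(self, X, sample_weight=None):
    if self.frozen is True:
        return

    # Let Distribution.summarize() run first so self.means exists before
    # its dtype is read (the same ordering Normal.summarize() uses).
    X, sample_weight = super(Normal, self).summarize(X, sample_weight=sample_weight)

    X = _cast_as_tensor(X, dtype=self.means.dtype)
    super().summarize(X.log(), sample_weight=sample_weight)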

A basic workaround is to call Distribution.summarize() on the distribution manually before fitting; as far as I can tell this initializes self.means, so the later dtype lookup succeeds:

import numpy as np
import pomegranate.distributions
import pomegranate.distributions._distribution

true_mean = 10
x_train = np.random.lognormal(mean=true_mean, sigma=5, size=10000).reshape([-1, 1])
d = pomegranate.distributions.LogNormal()
# Calling Distribution.summarize() directly initializes d's parameters (including
# self.means); the returned tensors are not needed here.
X, sample_weight = pomegranate.distributions._distribution.Distribution.summarize(d, x_train, sample_weight=None)
d.fit(x_train)
assert np.isclose(d.means[0], true_mean, atol=0.1)

Edit: a better workaround, which also works when the distribution is used inside a GeneralMixtureModel, is to patch summarize() on each instance:

from types import MethodType

import numpy as np
import pomegranate.gmm
import pomegranate.distributions
import pomegranate.distributions._distribution
import pomegranate._utils

def fixed_summarize(self, X, sample_weight=None):
    if self.frozen is True:
        return
    # Run Distribution.summarize() first so that self.means gets initialized.
    X, sample_weight = pomegranate.distributions._distribution.Distribution.summarize(self, X, sample_weight=sample_weight)
    X = pomegranate._utils._cast_as_tensor(X, dtype=self.means.dtype)
    # Normal.summarize() then accumulates the sufficient statistics of log(X).
    pomegranate.distributions.Normal.summarize(self, X.log(), sample_weight=sample_weight)

x_train = np.array([1, 2, 2, 3, 4, 20, 21, 22, 22, 23, 24, 10000, 10001, 10002]).reshape([-1, 1])
components = []
for i in range(2):
    d = pomegranate.distributions.LogNormal()
    # Bind the fixed summarize() to this instance so the mixture's fit loop uses it.
    d.summarize = MethodType(fixed_summarize, d)
    components.append(d)
model = pomegranate.gmm.GeneralMixtureModel(components, tol=0.01).fit(x_train)
[d.means for d in model.distributions]
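An alternative to binding fixed_summarize with MethodType is a small subclass carrying the same fix. PatchedLogNormal below is just a hypothetical name (not part of pomegranate), and its summarize() is the same body as fixed_summarize:

import numpy as np
import pomegranate.gmm
import pomegranate.distributions
import pomegranate.distributions._distribution
import pomegranate._utils

class PatchedLogNormal(pomegranate.distributions.LogNormal):
    # Same fix as fixed_summarize above, applied by subclassing instead of
    # patching each instance.
    def summarize(self, X, sample_weight=None):
        if self.frozen is True:
            return
        X, sample_weight = pomegranate.distributions._distribution.Distribution.summarize(self, X, sample_weight=sample_weight)
        X = pomegranate._utils._cast_as_tensor(X, dtype=self.means.dtype)
        pomegranate.distributions.Normal.summarize(self, X.log(), sample_weight=sample_weight)

x_train = np.array([1, 2, 2, 3, 4, 20, 21, 22, 22, 23, 24, 10000, 10001, 10002]).reshape([-1, 1])
model = pomegranate.gmm.GeneralMixtureModel([PatchedLogNormal() for _ in range(2)], tol=0.01).fit(x_train)
[d.means for d in model.distributions]

This keeps the fix in one place if more than one model needs it.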