Closed: AlexanderFabisch closed this 3 years ago.
Test code:
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
from sklearn.utils import check_X_y
from sklearn.utils.validation import (check_is_fitted, check_array,
FLOAT_DTYPES)
from sklearn.mixture import GaussianMixture
from gmr.gmm import GMM
class GaussianMixtureRegressorSklearn(MultiOutputMixin, RegressorMixin, BaseEstimator):
    """Regressor that fits a joint mixture with scikit-learn and predicts with gmr.

    ``fit`` trains a ``sklearn.mixture.GaussianMixture`` on the column-wise
    concatenation ``[X, y]``.  The resulting weights, means and covariances
    are wrapped in a ``gmr.gmm.GMM``.  ``predict`` then hands the input
    column indices and new ``X`` to that GMM's ``predict``.

    Parameters
    ----------
    n_components : int
        Number of mixture components.
    verbose : int, default=0
        Verbosity level forwarded to ``gmr.gmm.GMM``.
    random_state : int, RandomState instance or None, default=None
        Forwarded to ``GaussianMixture`` to control its initialization.
        ``None`` uses NumPy's global random state.
    R_diff : float, default=1e-4
        Convergence tolerance, forwarded as ``tol`` to ``GaussianMixture``.
    n_iter : int, default=500
        Maximum number of EM iterations, forwarded as ``max_iter``.
    init_params : str, default="random"
        Initialization method forwarded to ``GaussianMixture``.
    """

    def __init__(
            self, n_components, verbose=0, random_state=None, R_diff=1e-4,
            n_iter=500, init_params="random"):
        # Store hyperparameters untouched (scikit-learn convention) so that
        # get_params / set_params / clone work.
        self.n_components = n_components
        self.verbose = verbose
        self.random_state = random_state
        self.R_diff = R_diff
        self.n_iter = n_iter
        self.init_params = init_params

    def fit(self, X, y):
        """Fit the joint mixture on the stacked inputs and targets.

        Parameters
        ----------
        X : array-like
            Input samples; validated by ``check_X_y``.
        y : array-like
            Targets; a 1-D ``y`` is treated as a single output column.

        Returns
        -------
        self
        """
        gmm_ = GaussianMixture(
            self.n_components, init_params=self.init_params,
            max_iter=self.n_iter, tol=self.R_diff,
            # Previously stored but never used, so seeding the estimator
            # had no effect on initialization.
            random_state=self.random_state)
        X, y = check_X_y(X, y, estimator=gmm_, dtype=FLOAT_DTYPES,
                         multi_output=True)
        if y.ndim == 1:
            # Make the target 2-D so it can be stacked next to X.
            y = np.expand_dims(y, 1)
        # X occupies the first X.shape[1] columns of the joint data below.
        # These indices are what predict() passes to the gmr GMM.
        self.indices_ = np.arange(X.shape[1])
        gmm_.fit(np.hstack((X, y)))
        # Reuse the EM solution from scikit-learn (default
        # covariance_type="full") as the parameters of a gmr GMM.
        self.gmm_ = GMM(
            self.n_components, priors=gmm_.weights_, means=gmm_.means_,
            covariances=gmm_.covariances_, verbose=self.verbose)
        return self

    def predict(self, X):
        """Predict targets for ``X`` with the fitted mixture.

        Raises scikit-learn's ``NotFittedError`` (via ``check_is_fitted``)
        if ``fit`` has not been called.
        """
        check_is_fitted(self, ["gmm_", "indices_"])
        X = check_array(X, estimator=self.gmm_, dtype=FLOAT_DTYPES)
        return self.gmm_.predict(self.indices_, X)
import numpy as np
from sklearn.datasets import load_boston
from gmr.sklearn import GaussianMixtureRegressor
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed
# in 1.2, so this script only runs with an older scikit-learn release.
X, y = load_boston(return_X_y=True)
np.set_printoptions(precision=2, suppress=True)


def _report_repeated_scores(make_estimator, n_repeats=10, seed=2):
    """Repeatedly fit fresh estimators on (X, y) and print their scores.

    The global NumPy RNG is reseeded once, before the repetitions, so each
    benchmark starts from the same seed.  Every training-set ``score`` is
    printed with two decimals as it is computed.  After the last repetition
    all scores are printed together as an array.

    Returns the list of scores.
    """
    np.random.seed(seed)
    collected = []
    for _ in range(n_repeats):
        estimator = make_estimator()
        estimator.fit(X, y)
        r2 = estimator.score(X, y)
        print(f"{r2:.2f}")
        collected.append(r2)
    print(np.array(collected))
    return collected


# gmr's own estimator, initialized with k-means++.
scores = _report_repeated_scores(
    lambda: GaussianMixtureRegressor(
        n_components=2, verbose=10, R_diff=1e-7, init_params="kmeans++"))

# scikit-learn EM wrapped by the class above, initialized randomly.
scores = _report_repeated_scores(
    lambda: GaussianMixtureRegressorSklearn(
        n_components=2, verbose=10, R_diff=1e-7, init_params="random"))
Relevant examples are at the end of this pull request by @mralbu: https://github.com/AlexanderFabisch/gmr/pull/28
Setting init_params="kmeans++" drastically improves the stability of our results, but it is not the default initialization.