[MRG+2] Modification of GaussianMixture class. (scikit-learn#7123)
* Modification of GaussianMixture class.

The purpose here is to prepare the integration of BayesianGaussianMixture.

* Fix comments.

* Modification of the Docstring.

* Add license and author.

* Fix review and add tests for init.
tguillemot authored and TomDLT committed Oct 3, 2016
1 parent 13952cf commit 666fdff
Showing 3 changed files with 152 additions and 124 deletions.
44 changes: 24 additions & 20 deletions sklearn/mixture/base.py
@@ -2,6 +2,7 @@

# Author: Wei Xue <[email protected]>
# Modified by Thierry Guillemot <[email protected]>
# License: BSD 3 clause

from __future__ import print_function

@@ -136,7 +137,7 @@ def _initialize_parameters(self, X):
----------
X : array-like, shape (n_samples, n_features)
"""
n_samples = X.shape[0]
n_samples, _ = X.shape
random_state = check_random_state(self.random_state)

if self.init_params == 'kmeans':
@@ -145,7 +146,7 @@ def _initialize_parameters(self, X):
random_state=random_state).fit(X).labels_
resp[np.arange(n_samples), label] = 1
elif self.init_params == 'random':
resp = random_state.rand(X.shape[0], self.n_components)
resp = random_state.rand(n_samples, self.n_components)
resp /= resp.sum(axis=1)[:, np.newaxis]
else:
raise ValueError("Unimplemented initialization method '%s'"
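As a reference for the two strategies above, here is a minimal standalone sketch (a hypothetical init_resp helper, not part of the commit; it assumes numpy and scikit-learn's KMeans):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.utils import check_random_state

def init_resp(X, n_components, init_params='kmeans', random_state=None):
    # Hypothetical helper mirroring _initialize_parameters above.
    n_samples, _ = X.shape
    random_state = check_random_state(random_state)
    if init_params == 'kmeans':
        # Hard one-hot assignment from a k-means labeling.
        resp = np.zeros((n_samples, n_components))
        label = KMeans(n_clusters=n_components, n_init=1,
                       random_state=random_state).fit(X).labels_
        resp[np.arange(n_samples), label] = 1
    elif init_params == 'random':
        # Random soft assignment, normalized so each row sums to one.
        resp = random_state.rand(n_samples, n_components)
        resp /= resp.sum(axis=1)[:, np.newaxis]
    else:
        raise ValueError("Unimplemented initialization method '%s'"
                         % init_params)
    return resp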
@@ -191,32 +192,36 @@ def fit(self, X, y=None):
do_init = not(self.warm_start and hasattr(self, 'converged_'))
n_init = self.n_init if do_init else 1

max_log_likelihood = -np.infty
max_lower_bound = -np.infty
self.converged_ = False

n_samples, _ = X.shape
for init in range(n_init):
self._print_verbose_msg_init_beg(init)

if do_init:
self._initialize_parameters(X)
current_log_likelihood, resp = self._e_step(X)
self.lower_bound_ = np.infty

for n_iter in range(self.max_iter):
prev_log_likelihood = current_log_likelihood
prev_lower_bound = self.lower_bound_

self._m_step(X, resp)
current_log_likelihood, resp = self._e_step(X)
change = current_log_likelihood - prev_log_likelihood
log_prob_norm, log_resp = self._e_step(X)
self._m_step(X, log_resp)
self.lower_bound_ = self._compute_lower_bound(
log_resp, log_prob_norm)

change = self.lower_bound_ - prev_lower_bound
self._print_verbose_msg_iter_end(n_iter, change)

if abs(change) < self.tol:
self.converged_ = True
break

self._print_verbose_msg_init_end(current_log_likelihood)
self._print_verbose_msg_init_end(self.lower_bound_)

if current_log_likelihood > max_log_likelihood:
max_log_likelihood = current_log_likelihood
if self.lower_bound_ > max_lower_bound:
max_lower_bound = self.lower_bound_
best_params = self._get_parameters()
best_n_iter = n_iter
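The loop above is the core of the change: EM now tracks a generic lower bound rather than the log-likelihood directly, so a future BayesianGaussianMixture can reuse it. A sketch of one EM run under that contract (gm, X, max_iter and tol are assumed to be defined; the hook names follow this commit):

import numpy as np

lower_bound = np.infty  # seeded as in the hunk above; first change is infinite
for n_iter in range(max_iter):
    prev_lower_bound = lower_bound
    log_prob_norm, log_resp = gm._e_step(X)  # E step: log responsibilities
    gm._m_step(X, log_resp)                  # M step: refit the parameters
    lower_bound = gm._compute_lower_bound(log_resp, log_prob_norm)
    if abs(lower_bound - prev_lower_bound) < tol:
        break  # converged for this initialization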

@@ -242,21 +247,23 @@ def _e_step(self, X):
Returns
-------
log-likelihood : scalar
log_prob_norm : array, shape (n_samples,)
log p(X)
responsibility : array, shape (n_samples, n_components)
log_responsibility : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
pass

@abstractmethod
def _m_step(self, X, resp):
def _m_step(self, X, log_resp):
"""M step.
Parameters
----------
X : array-like, shape (n_samples, n_features)
resp : array-like, shape (n_samples, n_components)
log_resp : array-like, shape (n_samples, n_components)
"""
pass

@@ -342,7 +349,7 @@ def predict_proba(self, X):
"""
self._check_is_fitted()
X = _check_X(X, None, self.means_.shape[1])
_, _, log_resp = self._estimate_log_prob_resp(X)
_, log_resp = self._estimate_log_prob_resp(X)
return np.exp(log_resp)

def _estimate_weighted_log_prob(self, X):
@@ -400,9 +407,6 @@ def _estimate_log_prob_resp(self, X):
log_prob_norm : array, shape (n_samples,)
log p(X)
log_prob : array, shape (n_samples, n_components)
log p(X|Z) + log weights
log_responsibilities : array, shape (n_samples, n_components)
logarithm of the responsibilities
"""
@@ -411,7 +415,7 @@ def _estimate_log_prob_resp(self, X):
with np.errstate(under='ignore'):
# ignore underflow
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
return log_prob_norm, weighted_log_prob, log_resp
return log_prob_norm, log_resp

def _print_verbose_msg_init_beg(self, n_init):
"""Print verbose message on initialization."""
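The normalization performed by _estimate_log_prob_resp is a log-sum-exp over components; a short numerical sketch of the same computation (illustration only, using scipy's logsumexp, which lives in scipy.special in recent versions, rather than the helper scikit-learn uses internally):

import numpy as np
from scipy.special import logsumexp

rng = np.random.RandomState(0)
weighted_log_prob = rng.randn(4, 2)  # stand-in for log p(X|Z) + log weights

log_prob_norm = logsumexp(weighted_log_prob, axis=1)         # log p(X)
log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]  # log responsibilities
assert np.allclose(np.exp(log_resp).sum(axis=1), 1.)         # rows sum to one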
165 changes: 80 additions & 85 deletions sklearn/mixture/gaussian_mixture.py
@@ -2,6 +2,7 @@

# Author: Wei Xue <[email protected]>
# Modified by Thierry Guillemot <[email protected]>
# License: BSD 3 clause

import numpy as np

@@ -11,6 +12,7 @@
from ..externals.six.moves import zip
from ..utils import check_array
from ..utils.validation import check_is_fitted
from ..utils.extmath import row_norms


###############################################################################
@@ -336,110 +338,99 @@ def _compute_precision_cholesky(covariances, covariance_type):

###############################################################################
# Gaussian mixture probability estimators
def _estimate_log_gaussian_prob_full(X, means, precisions_chol):
"""Estimate the log Gaussian probability for 'full' precision.
def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features):
"""Compute the log-det of the cholesky decomposition of matrices.
Parameters
----------
X : array-like, shape (n_samples, n_features)
matrix_chol : array-like,
Cholesky decompositions of the matrices.
'full' : shape of (n_components, n_features, n_features)
'tied' : shape of (n_features, n_features)
'diag' : shape of (n_components, n_features)
'spherical' : shape of (n_components,)
means : array-like, shape (n_components, n_features)
covariance_type : {'full', 'tied', 'diag', 'spherical'}
precisions_chol : array-like, shape (n_components, n_features, n_features)
Cholesky decompositions of the precision matrices.
n_features : int
Number of features.
Returns
-------
log_prob : array, shape (n_samples, n_components)
log_det_precision_chol : array-like, shape (n_components,)
The determinant of the cholesky decomposition of the matrix.
"""
n_samples, n_features = X.shape
n_components, _ = means.shape
log_prob = np.empty((n_samples, n_components))
for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))
y = np.dot(X - mu, prec_chol)
log_prob[:, k] = -.5 * (n_features * np.log(2. * np.pi) + log_det +
np.sum(np.square(y), axis=1))
return log_prob


def _estimate_log_gaussian_prob_tied(X, means, precision_chol):
"""Estimate the log Gaussian probability for 'tied' precision.
if covariance_type == 'full':
n_components, _, _ = matrix_chol.shape
log_det_chol = (np.sum(np.log(
matrix_chol.reshape(
n_components, -1)[:, ::n_features + 1]), 1))

Parameters
----------
X : array-like, shape (n_samples, n_features)
elif covariance_type == 'tied':
log_det_chol = (np.sum(np.log(np.diag(matrix_chol))))

means : array-like, shape (n_components, n_features)
elif covariance_type == 'diag':
log_det_chol = (np.sum(np.log(matrix_chol), axis=1))

precision_chol : array-like, shape (n_features, n_features)
Cholesky decomposition of the precision matrix.
else:
log_det_chol = n_features * (np.log(matrix_chol))

Returns
-------
log_prob : array-like, shape (n_samples, n_components)
"""
n_samples, n_features = X.shape
n_components, _ = means.shape
log_prob = np.empty((n_samples, n_components))
log_det = -2. * np.sum(np.log(np.diagonal(precision_chol)))
for k, mu in enumerate(means):
y = np.dot(X - mu, precision_chol)
log_prob[:, k] = np.sum(np.square(y), axis=1)
log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det + log_prob)
return log_prob
return log_det_chol
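In the 'full' branch above, the strided reshape picks out the diagonal of each Cholesky factor in one vectorized step; a quick check of that equivalence (illustration only, not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
n_components, n_features = 3, 4
A = rng.rand(n_components, n_features, n_features)
spd = np.matmul(A, A.transpose(0, 2, 1)) + n_features * np.eye(n_features)
chol = np.linalg.cholesky(spd)  # one lower-triangular factor per component

# Striding by n_features + 1 through each flattened factor walks its diagonal.
log_det = np.sum(np.log(chol.reshape(n_components, -1)[:, ::n_features + 1]),
                 axis=1)
expected = np.array([np.sum(np.log(np.diag(c))) for c in chol])
assert np.allclose(log_det, expected)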


def _estimate_log_gaussian_prob_diag(X, means, precisions_chol):
"""Estimate the log Gaussian probability for 'diag' precision.
def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
"""Estimate the log Gaussian probability.
Parameters
----------
X : array-like, shape (n_samples, n_features)
means : array-like, shape (n_components, n_features)
precisions_chol : array-like, shape (n_components, n_features)
precisions_chol : array-like,
Cholesky decompositions of the precision matrices.
'full' : shape of (n_components, n_features, n_features)
'tied' : shape of (n_features, n_features)
'diag' : shape of (n_components, n_features)
'spherical' : shape of (n_components,)
Returns
-------
log_prob : array-like, shape (n_samples, n_components)
"""
n_samples, n_features = X.shape
precisions = precisions_chol ** 2
log_prob = -.5 * (n_features * np.log(2. * np.pi) -
np.sum(np.log(precisions), 1) +
np.sum((means ** 2 * precisions), 1) -
2. * np.dot(X, (means * precisions).T) +
np.dot(X ** 2, precisions.T))
return log_prob


def _estimate_log_gaussian_prob_spherical(X, means, precisions_chol):
"""Estimate the log Gaussian probability for 'spherical' precision.
Parameters
----------
X : array-like, shape (n_samples, n_features)
means : array-like, shape (n_components, n_features)
precisions_chol : array-like, shape (n_components, )
Cholesky decompositions of the precision matrices.
covariance_type : {'full', 'tied', 'diag', 'spherical'}
Returns
-------
log_prob : array-like, shape (n_samples, n_components)
log_prob : array, shape (n_samples, n_components)
"""
n_samples, n_features = X.shape
precisions = precisions_chol ** 2
log_prob = -.5 * (n_features * np.log(2 * np.pi) -
n_features * np.log(precisions) +
np.sum(means ** 2, 1) * precisions -
2 * np.dot(X, means.T * precisions) +
np.outer(np.sum(X ** 2, axis=1), precisions))
return log_prob
n_components, _ = means.shape
# det(precision_chol) is half of det(precision)
log_det = _compute_log_det_cholesky(
precisions_chol, covariance_type, n_features)

if covariance_type == 'full':
log_prob = np.empty((n_samples, n_components))
for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
log_prob[:, k] = np.sum(np.square(y), axis=1)

elif covariance_type == 'tied':
log_prob = np.empty((n_samples, n_components))
for k, mu in enumerate(means):
y = np.dot(X, precisions_chol) - np.dot(mu, precisions_chol)
log_prob[:, k] = np.sum(np.square(y), axis=1)

elif covariance_type == 'diag':
precisions = precisions_chol ** 2
log_prob = (np.sum((means ** 2 * precisions), 1) -
2. * np.dot(X, (means * precisions).T) +
np.dot(X ** 2, precisions.T))

elif covariance_type == 'spherical':
precisions = precisions_chol ** 2
log_prob = (np.sum(means ** 2, 1) * precisions -
2 * np.dot(X, means.T * precisions) +
np.outer(row_norms(X, squared=True), precisions))
return -.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det
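The merged function computes log N(x | mu_k, Sigma_k) from the precision Cholesky factors: with precision L_k L_k^T, the quadratic term is ||(x - mu_k) L_k||^2, and log det(L_k) contributes half of log det(precision). One way to sanity-check the 'full' branch against scipy (illustration only; all names are local to the sketch):

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.RandomState(0)
n_samples, n_features = 5, 3
X = rng.rand(n_samples, n_features)
mu = rng.rand(n_features)
prec_chol = np.linalg.cholesky(2. * np.eye(n_features))  # precision = L L^T
cov = np.linalg.inv(np.dot(prec_chol, prec_chol.T))

y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
log_det = np.sum(np.log(np.diag(prec_chol)))  # half of log det(precision)
log_prob = -.5 * (n_features * np.log(2 * np.pi) +
                  np.sum(np.square(y), axis=1)) + log_det
assert np.allclose(log_prob, multivariate_normal(mu, cov).logpdf(X))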


class GaussianMixture(BaseMixture):
@@ -475,7 +466,7 @@ class GaussianMixture(BaseMixture):
The number of EM iterations to perform.
n_init : int, defaults to 1.
The number of initializations to perform. The best results is kept.
The number of initializations to perform. The best results are kept.
init_params : {'kmeans', 'random'}, defaults to 'kmeans'.
The method used to initialize the weights, the means and the
@@ -563,6 +554,9 @@ class GaussianMixture(BaseMixture):
n_iter_ : int
Number of steps used by the best fit of EM to reach convergence.
lower_bound_ : float
Log-likelihood of the best fit of EM.
"""

def __init__(self, n_components=1, covariance_type='full', tol=1e-3,
Expand Down Expand Up @@ -638,7 +632,7 @@ def _initialize(self, X, resp):
self.precisions_cholesky_ = self.precisions_init

def _e_step(self, X):
log_prob_norm, _, log_resp = self._estimate_log_prob_resp(X)
log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
return np.mean(log_prob_norm), np.exp(log_resp)

def _m_step(self, X, resp):
@@ -651,24 +645,25 @@ def _m_step(self, X, resp):
self.covariances_, self.covariance_type)

def _estimate_log_prob(self, X):
return {"full": _estimate_log_gaussian_prob_full,
"tied": _estimate_log_gaussian_prob_tied,
"diag": _estimate_log_gaussian_prob_diag,
"spherical": _estimate_log_gaussian_prob_spherical
}[self.covariance_type](X, self.means_,
self.precisions_cholesky_)
return _estimate_log_gaussian_prob(
X, self.means_, self.precisions_cholesky_, self.covariance_type)

def _estimate_log_weights(self):
return np.log(self.weights_)

def _compute_lower_bound(self, _, log_prob_norm):
return log_prob_norm

def _check_is_fitted(self):
check_is_fitted(self, ['weights_', 'means_', 'precisions_cholesky_'])

def _get_parameters(self):
return self.weights_, self.means_, self.precisions_cholesky_
return (self.weights_, self.means_, self.covariances_,
self.precisions_cholesky_)

def _set_parameters(self, params):
self.weights_, self.means_, self.precisions_cholesky_ = params
(self.weights_, self.means_, self.covariances_,
self.precisions_cholesky_) = params

# Attributes computation
_, n_features = self.means_.shape
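End to end, lower_bound_ is the quantity fit() now maximizes across initializations. A minimal usage sketch against this version of the estimator (the data and parameter values are illustrative):

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(100, 2) + 5., rng.randn(100, 2)])  # two blobs

gm = GaussianMixture(n_components=2, covariance_type='full',
                     n_init=3, random_state=0).fit(X)
print(gm.converged_, gm.n_iter_, gm.lower_bound_)  # best of the 3 runs
resp = gm.predict_proba(X)  # exp of the log responsibilities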
