forked from scikit-learn/scikit-learn
[MRG+2] Modification of GaussianMixture class. (scikit-learn#7123)
* Modification of GaussianMixture class. The purpose here is to prepare the integration of BayesianGaussianMixture.
* Fix comments.
* Modification of the Docstring.
* Add license and author.
* Fix review and add tests for init.
1 parent 13952cf · commit 666fdff
Showing 3 changed files with 152 additions and 124 deletions.
sklearn/mixture/base.py
@@ -2,6 +2,7 @@

 # Author: Wei Xue <[email protected]>
+# Modified by Thierry Guillemot <[email protected]>
 # License: BSD 3 clause

 from __future__ import print_function
@@ -136,7 +137,7 @@ def _initialize_parameters(self, X):
         ----------
         X : array-like, shape (n_samples, n_features)
         """
-        n_samples = X.shape[0]
+        n_samples, _ = X.shape
         random_state = check_random_state(self.random_state)

         if self.init_params == 'kmeans':
@@ -145,7 +146,7 @@ def _initialize_parameters(self, X):
                                    random_state=random_state).fit(X).labels_
             resp[np.arange(n_samples), label] = 1
         elif self.init_params == 'random':
-            resp = random_state.rand(X.shape[0], self.n_components)
+            resp = random_state.rand(n_samples, self.n_components)
             resp /= resp.sum(axis=1)[:, np.newaxis]
         else:
             raise ValueError("Unimplemented initialization method '%s'"
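Both initialization branches construct the same object: an (n_samples, n_components) responsibility matrix whose rows sum to one; 'kmeans' one-hot encodes cluster labels while 'random' normalizes uniform draws. A minimal standalone sketch of the two schemes, with toy data and variable names of my own rather than the class's:

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(0)
X = rng.randn(100, 2)          # toy data
n_components = 3

# 'kmeans': one-hot responsibilities from the cluster labels
resp_kmeans = np.zeros((X.shape[0], n_components))
labels = KMeans(n_clusters=n_components, n_init=1,
                random_state=rng).fit(X).labels_
resp_kmeans[np.arange(X.shape[0]), labels] = 1

# 'random': uniform draws, normalized so each row sums to one
resp_random = rng.rand(X.shape[0], n_components)
resp_random /= resp_random.sum(axis=1)[:, np.newaxis]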
@@ -191,32 +192,36 @@ def fit(self, X, y=None):
         do_init = not(self.warm_start and hasattr(self, 'converged_'))
         n_init = self.n_init if do_init else 1

-        max_log_likelihood = -np.infty
+        max_lower_bound = -np.infty
         self.converged_ = False

+        n_samples, _ = X.shape
         for init in range(n_init):
             self._print_verbose_msg_init_beg(init)

             if do_init:
                 self._initialize_parameters(X)
-            current_log_likelihood, resp = self._e_step(X)
+            self.lower_bound_ = np.infty

             for n_iter in range(self.max_iter):
-                prev_log_likelihood = current_log_likelihood
+                prev_lower_bound = self.lower_bound_

-                self._m_step(X, resp)
-                current_log_likelihood, resp = self._e_step(X)
-                change = current_log_likelihood - prev_log_likelihood
+                log_prob_norm, log_resp = self._e_step(X)
+                self._m_step(X, log_resp)
+                self.lower_bound_ = self._compute_lower_bound(
+                    log_resp, log_prob_norm)
+
+                change = self.lower_bound_ - prev_lower_bound
                 self._print_verbose_msg_iter_end(n_iter, change)

                 if abs(change) < self.tol:
                     self.converged_ = True
                     break

-            self._print_verbose_msg_init_end(current_log_likelihood)
+            self._print_verbose_msg_init_end(self.lower_bound_)

-            if current_log_likelihood > max_log_likelihood:
-                max_log_likelihood = current_log_likelihood
+            if self.lower_bound_ > max_lower_bound:
+                max_lower_bound = self.lower_bound_
                 best_params = self._get_parameters()
                 best_n_iter = n_iter
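With this change, fit() tracks self.lower_bound_ across EM iterations, declares convergence once its absolute change falls below tol, and keeps the parameters of the initialization with the largest bound. A quick usage sketch against the public API, assuming a scikit-learn release that includes this change (0.18 or later) and toy data:

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(200, 2) - 3, rng.randn(200, 2) + 3])

gm = GaussianMixture(n_components=2, n_init=3, tol=1e-3,
                     random_state=0).fit(X)
print(gm.converged_)    # True once |change in lower bound| < tol
print(gm.n_iter_)       # EM iterations used by the selected initialization
print(gm.lower_bound_)  # the lower bound reported for the fit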
@@ -242,21 +247,23 @@ def _e_step(self, X):
         Returns
         -------
-        log-likelihood : scalar
+        log_prob_norm : array, shape (n_samples,)
+            log p(X)

-        responsibility : array, shape (n_samples, n_components)
+        log_responsibility : array, shape (n_samples, n_components)
+            logarithm of the responsibilities
         """
         pass

     @abstractmethod
-    def _m_step(self, X, resp):
+    def _m_step(self, X, log_resp):
         """M step.

         Parameters
         ----------
         X : array-like, shape (n_samples, n_features)

-        resp : array-like, shape (n_samples, n_components)
+        log_resp : array-like, shape (n_samples, n_components)
         """
         pass
@@ -342,7 +349,7 @@ def predict_proba(self, X):
         """
         self._check_is_fitted()
         X = _check_X(X, None, self.means_.shape[1])
-        _, _, log_resp = self._estimate_log_prob_resp(X)
+        _, log_resp = self._estimate_log_prob_resp(X)
         return np.exp(log_resp)

     def _estimate_weighted_log_prob(self, X):
@@ -400,9 +407,6 @@ def _estimate_log_prob_resp(self, X):
         log_prob_norm : array, shape (n_samples,)
             log p(X)

-        log_prob : array, shape (n_samples, n_components)
-            log p(X|Z) + log weights
-
         log_responsibilities : array, shape (n_samples, n_components)
             logarithm of the responsibilities
         """
@@ -411,7 +415,7 @@ def _estimate_log_prob_resp(self, X):
         with np.errstate(under='ignore'):
             # ignore underflow
             log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
-        return log_prob_norm, weighted_log_prob, log_resp
+        return log_prob_norm, log_resp

     def _print_verbose_msg_init_beg(self, n_init):
         """Print verbose message on initialization."""
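_estimate_log_prob_resp now returns only the per-sample normalizer log p(X) and the log responsibilities, dropping the intermediate weighted_log_prob from the return value. The normalization itself is a log-sum-exp; here is a minimal numpy/scipy sketch of the same computation, where the random weighted_log_prob is a stand-in for log p(X|Z) + log weights:

import numpy as np
from scipy.special import logsumexp  # scipy.misc.logsumexp on old scipy

rng = np.random.RandomState(0)
weighted_log_prob = rng.randn(5, 3)  # stand-in for log p(X|Z) + log weights

log_prob_norm = logsumexp(weighted_log_prob, axis=1)  # log p(X)
with np.errstate(under='ignore'):
    # ignore underflow, as in the diff above
    log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]

assert np.allclose(np.exp(log_resp).sum(axis=1), 1.0)  # rows sum to one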
sklearn/mixture/gaussian_mixture.py
@@ -2,6 +2,7 @@

 # Author: Wei Xue <[email protected]>
+# Modified by Thierry Guillemot <[email protected]>
 # License: BSD 3 clause

 import numpy as np
@@ -11,6 +12,7 @@
 from ..externals.six.moves import zip
 from ..utils import check_array
 from ..utils.validation import check_is_fitted
+from ..utils.extmath import row_norms


 ###############################################################################
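The new row_norms import replaces the explicit np.sum(X ** 2, axis=1) in the 'spherical' branch further down. The two are equivalent, as this small check with a toy array illustrates:

import numpy as np
from sklearn.utils.extmath import row_norms

X = np.arange(6.).reshape(3, 2)
assert np.allclose(row_norms(X, squared=True), np.sum(X ** 2, axis=1))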
@@ -336,110 +338,99 @@ def _compute_precision_cholesky(covariances, covariance_type):

 ###############################################################################
 # Gaussian mixture probability estimators
-def _estimate_log_gaussian_prob_full(X, means, precisions_chol):
-    """Estimate the log Gaussian probability for 'full' precision.
-
-    Parameters
-    ----------
-    X : array-like, shape (n_samples, n_features)
-
-    means : array-like, shape (n_components, n_features)
-
-    precisions_chol : array-like, shape (n_components, n_features, n_features)
-        Cholesky decompositions of the precision matrices.
-
-    Returns
-    -------
-    log_prob : array, shape (n_samples, n_components)
-    """
-    n_samples, n_features = X.shape
-    n_components, _ = means.shape
-    log_prob = np.empty((n_samples, n_components))
-    for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
-        log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))
-        y = np.dot(X - mu, prec_chol)
-        log_prob[:, k] = -.5 * (n_features * np.log(2. * np.pi) + log_det +
-                                np.sum(np.square(y), axis=1))
-    return log_prob
-
-
-def _estimate_log_gaussian_prob_tied(X, means, precision_chol):
-    """Estimate the log Gaussian probability for 'tied' precision.
-
-    Parameters
-    ----------
-    X : array-like, shape (n_samples, n_features)
-
-    means : array-like, shape (n_components, n_features)
-
-    precision_chol : array-like, shape (n_features, n_features)
-        Cholesky decomposition of the precision matrix.
-
-    Returns
-    -------
-    log_prob : array-like, shape (n_samples, n_components)
-    """
-    n_samples, n_features = X.shape
-    n_components, _ = means.shape
-    log_prob = np.empty((n_samples, n_components))
-    log_det = -2. * np.sum(np.log(np.diagonal(precision_chol)))
-    for k, mu in enumerate(means):
-        y = np.dot(X - mu, precision_chol)
-        log_prob[:, k] = np.sum(np.square(y), axis=1)
-    log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det + log_prob)
-    return log_prob
-
-
-def _estimate_log_gaussian_prob_diag(X, means, precisions_chol):
-    """Estimate the log Gaussian probability for 'diag' precision.
-
-    Parameters
-    ----------
-    X : array-like, shape (n_samples, n_features)
-
-    means : array-like, shape (n_components, n_features)
-
-    precisions_chol : array-like, shape (n_components, n_features)
-        Cholesky decompositions of the precision matrices.
-
-    Returns
-    -------
-    log_prob : array-like, shape (n_samples, n_components)
-    """
-    n_samples, n_features = X.shape
-    precisions = precisions_chol ** 2
-    log_prob = -.5 * (n_features * np.log(2. * np.pi) -
-                      np.sum(np.log(precisions), 1) +
-                      np.sum((means ** 2 * precisions), 1) -
-                      2. * np.dot(X, (means * precisions).T) +
-                      np.dot(X ** 2, precisions.T))
-    return log_prob
-
-
-def _estimate_log_gaussian_prob_spherical(X, means, precisions_chol):
-    """Estimate the log Gaussian probability for 'spherical' precision.
-
-    Parameters
-    ----------
-    X : array-like, shape (n_samples, n_features)
-
-    means : array-like, shape (n_components, n_features)
-
-    precisions_chol : array-like, shape (n_components, )
-        Cholesky decompositions of the precision matrices.
-
-    Returns
-    -------
-    log_prob : array-like, shape (n_samples, n_components)
-    """
-    n_samples, n_features = X.shape
-    precisions = precisions_chol ** 2
-    log_prob = -.5 * (n_features * np.log(2 * np.pi) -
-                      n_features * np.log(precisions) +
-                      np.sum(means ** 2, 1) * precisions -
-                      2 * np.dot(X, means.T * precisions) +
-                      np.outer(np.sum(X ** 2, axis=1), precisions))
-    return log_prob
+def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features):
+    """Compute the log-det of the cholesky decomposition of matrices.
+
+    Parameters
+    ----------
+    matrix_chol : array-like,
+        Cholesky decompositions of the matrices.
+        'full' : shape of (n_components, n_features, n_features)
+        'tied' : shape of (n_features, n_features)
+        'diag' : shape of (n_components, n_features)
+        'spherical' : shape of (n_components,)
+
+    covariance_type : {'full', 'tied', 'diag', 'spherical'}
+
+    n_features : int
+        Number of features.
+
+    Returns
+    -------
+    log_det_precision_chol : array-like, shape (n_components,)
+        The determinant of the cholesky decomposition of each
+        matrix.
+    """
+    if covariance_type == 'full':
+        n_components, _, _ = matrix_chol.shape
+        log_det_chol = (np.sum(np.log(
+            matrix_chol.reshape(
+                n_components, -1)[:, ::n_features + 1]), 1))
+
+    elif covariance_type == 'tied':
+        log_det_chol = (np.sum(np.log(np.diag(matrix_chol))))
+
+    elif covariance_type == 'diag':
+        log_det_chol = (np.sum(np.log(matrix_chol), axis=1))
+
+    else:
+        log_det_chol = n_features * (np.log(matrix_chol))
+
+    return log_det_chol
+
+
+def _estimate_log_gaussian_prob(X, means, precisions_chol, covariance_type):
+    """Estimate the log Gaussian probability.
+
+    Parameters
+    ----------
+    X : array-like, shape (n_samples, n_features)
+
+    means : array-like, shape (n_components, n_features)
+
+    precisions_chol : array-like,
+        Cholesky decompositions of the precision matrices.
+        'full' : shape of (n_components, n_features, n_features)
+        'tied' : shape of (n_features, n_features)
+        'diag' : shape of (n_components, n_features)
+        'spherical' : shape of (n_components,)
+
+    covariance_type : {'full', 'tied', 'diag', 'spherical'}
+
+    Returns
+    -------
+    log_prob : array, shape (n_samples, n_components)
+    """
+    n_samples, n_features = X.shape
+    n_components, _ = means.shape
+    # det(precision_chol) is half of det(precision)
+    log_det = _compute_log_det_cholesky(
+        precisions_chol, covariance_type, n_features)
+
+    if covariance_type == 'full':
+        log_prob = np.empty((n_samples, n_components))
+        for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
+            y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
+            log_prob[:, k] = np.sum(np.square(y), axis=1)
+
+    elif covariance_type == 'tied':
+        log_prob = np.empty((n_samples, n_components))
+        for k, mu in enumerate(means):
+            y = np.dot(X, precisions_chol) - np.dot(mu, precisions_chol)
+            log_prob[:, k] = np.sum(np.square(y), axis=1)
+
+    elif covariance_type == 'diag':
+        precisions = precisions_chol ** 2
+        log_prob = (np.sum((means ** 2 * precisions), 1) -
+                    2. * np.dot(X, (means * precisions).T) +
+                    np.dot(X ** 2, precisions.T))
+
+    elif covariance_type == 'spherical':
+        precisions = precisions_chol ** 2
+        log_prob = (np.sum(means ** 2, 1) * precisions -
+                    2 * np.dot(X, means.T * precisions) +
+                    np.outer(row_norms(X, squared=True), precisions))
+    return -.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det


 class GaussianMixture(BaseMixture):
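The four per-covariance-type estimators above collapse into a single function parameterized by covariance_type, with the log-determinant term factored out into _compute_log_det_cholesky. As a sanity check on the 'full' branch, here is a hedged standalone sketch, using only numpy/scipy and variable names of my own, that compares the Cholesky-based formula against scipy.stats.multivariate_normal:

import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.RandomState(0)
n_samples, n_features = 10, 3
X = rng.randn(n_samples, n_features)
mu = rng.randn(n_features)

# random SPD covariance and the Cholesky factor of its inverse (precision)
A = rng.randn(n_features, n_features)
cov = np.dot(A, A.T) + n_features * np.eye(n_features)
cov_chol = np.linalg.cholesky(cov)
prec_chol = np.linalg.inv(cov_chol).T   # as stored in precisions_cholesky_

# log N(x | mu, cov) via the 'full' branch above
y = np.dot(X, prec_chol) - np.dot(mu, prec_chol)
log_det = np.sum(np.log(np.diagonal(prec_chol)))  # = 0.5 * log det(precision)
log_prob = -.5 * (n_features * np.log(2 * np.pi) +
                  np.sum(np.square(y), axis=1)) + log_det

assert np.allclose(log_prob, multivariate_normal(mu, cov).logpdf(X))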
@@ -475,7 +466,7 @@ class GaussianMixture(BaseMixture):
         The number of EM iterations to perform.

     n_init : int, defaults to 1.
-        The number of initializations to perform. The best results is kept.
+        The number of initializations to perform. The best results are kept.

     init_params : {'kmeans', 'random'}, defaults to 'kmeans'.
         The method used to initialize the weights, the means and the
@@ -563,6 +554,9 @@ class GaussianMixture(BaseMixture):
     n_iter_ : int
         Number of step used by the best fit of EM to reach the convergence.

+    lower_bound_ : float
+        Log-likelihood of the best fit of EM.
+
     """

     def __init__(self, n_components=1, covariance_type='full', tol=1e-3,
@@ -638,7 +632,7 @@ def _initialize(self, X, resp):
             self.precisions_cholesky_ = self.precisions_init

     def _e_step(self, X):
-        log_prob_norm, _, log_resp = self._estimate_log_prob_resp(X)
+        log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
         return np.mean(log_prob_norm), np.exp(log_resp)

     def _m_step(self, X, resp):
@@ -651,24 +645,25 @@ def _m_step(self, X, resp):
                                   self.covariances_, self.covariance_type)

     def _estimate_log_prob(self, X):
-        return {"full": _estimate_log_gaussian_prob_full,
-                "tied": _estimate_log_gaussian_prob_tied,
-                "diag": _estimate_log_gaussian_prob_diag,
-                "spherical": _estimate_log_gaussian_prob_spherical
-                }[self.covariance_type](X, self.means_,
-                                        self.precisions_cholesky_)
+        return _estimate_log_gaussian_prob(
+            X, self.means_, self.precisions_cholesky_, self.covariance_type)

     def _estimate_log_weights(self):
         return np.log(self.weights_)

+    def _compute_lower_bound(self, _, log_prob_norm):
+        return log_prob_norm
+
     def _check_is_fitted(self):
         check_is_fitted(self, ['weights_', 'means_', 'precisions_cholesky_'])

     def _get_parameters(self):
-        return self.weights_, self.means_, self.precisions_cholesky_
+        return (self.weights_, self.means_, self.covariances_,
+                self.precisions_cholesky_)

     def _set_parameters(self, params):
-        self.weights_, self.means_, self.precisions_cholesky_ = params
+        (self.weights_, self.means_, self.covariances_,
+         self.precisions_cholesky_) = params

         # Attributes computation
         _, n_features = self.means_.shape
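_get_parameters and _set_parameters now round-trip covariances_ alongside the precision Cholesky factors, after which the class recomputes dependent attributes. For 'full' covariances a precision matrix can be recovered from its Cholesky factor as chol times its transpose; the sketch below shows that recovery in isolation (toy shapes, and the assumption that this mirrors the attribute computation the diff truncates):

import numpy as np

rng = np.random.RandomState(0)
n_components, n_features = 2, 3

# stand-in for precisions_cholesky_ with covariance_type='full'
precisions_cholesky = rng.randn(n_components, n_features, n_features)

# recover each precision matrix from its Cholesky factor
precisions = np.empty_like(precisions_cholesky)
for k, prec_chol in enumerate(precisions_cholesky):
    precisions[k] = np.dot(prec_chol, prec_chol.T)

assert precisions.shape == (n_components, n_features, n_features)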