Skip to content

Commit

Permalink
[MRG+2] Fixed n**2 memory blowup in _labels_inertia_precompute_dense (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
Erotemic authored and jnothman committed Oct 27, 2016
1 parent 94c2094 commit 061803c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ New features
Enhancements
............

- :class:`cluster.MiniBatchKMeans` and :class:`cluster.KMeans`
now uses significantly less memory when assigning data points to their
nearest cluster center.
(`#7721 <https://github.com/scikit-learn/scikit-learn/pull/7721>`_)
By `Jon Crall`_.

- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661
<https://github.com/scikit-learn/scikit-learn/pull/7661>`_) by `Alyssa
Expand Down
20 changes: 9 additions & 11 deletions sklearn/cluster/k_means_.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ..base import BaseEstimator, ClusterMixin, TransformerMixin
from ..metrics.pairwise import euclidean_distances
from ..metrics.pairwise import pairwise_distances_argmin_min
from ..utils.extmath import row_norms, squared_norm, stable_cumsum
from ..utils.sparsefuncs_fast import assign_rows_csr
from ..utils.sparsefuncs import mean_variance_axis
Expand Down Expand Up @@ -552,17 +553,14 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances):
"""
n_samples = X.shape[0]
k = centers.shape[0]
all_distances = euclidean_distances(centers, X, x_squared_norms,
squared=True)
labels = np.empty(n_samples, dtype=np.int32)
labels.fill(-1)
mindist = np.empty(n_samples)
mindist.fill(np.infty)
for center_id in range(k):
dist = all_distances[center_id]
labels[dist < mindist] = center_id
mindist = np.minimum(dist, mindist)

# Breakup nearest neighbor distance computation into batches to prevent
# memory blowup in the case of a large number of samples and clusters.
# TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs.
labels, mindist = pairwise_distances_argmin_min(
X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
# cython k-means code assumes int32 inputs
labels = labels.astype(np.int32)
if n_samples == distances.shape[0]:
# distances will be changed in-place
distances[:] = mindist
Expand Down

0 comments on commit 061803c

Please sign in to comment.