From a9c6ad9baf878015653569109091828ceaf2db8e Mon Sep 17 00:00:00 2001 From: Adrin Jalali Date: Wed, 5 Sep 2018 17:55:07 +0200 Subject: [PATCH] [MRG+1] break the tie in Meanshift in case cluster intensities are the same (#11901) --- doc/whats_new/v0.20.rst | 6 ++++++ sklearn/cluster/mean_shift_.py | 8 +++++--- sklearn/cluster/tests/test_mean_shift.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 2ed336b782174..46b262896145c 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -63,6 +63,7 @@ parameters, may produce different models from the previous version. This often occurs due to changes in the modelling logic (bug fixes or enhancements), or in random sampling procedures. +- :class:`cluster.MeanShift` (bug fix) - :class:`decomposition.IncrementalPCA` in Python 2 (bug fix) - :class:`decomposition.SparsePCA` (bug fix) - :class:`ensemble.GradientBoostingClassifier` (bug fix affecting feature importances) @@ -151,6 +152,11 @@ Support for Python 3.3 has been officially dropped. ``n_iter_`` attribute in the docstring of :class:`cluster.KMeans`. :issue:`11353` by :user:`Jeremie du Boisberranger `. +- |Fix| Fixed a bug in :func:`cluster.mean_shift` where the assigned labels + were not deterministic if there were multiple clusters with the same + intensities. + :issue:`11901` by :user:`Adrin Jalali `. + - |API| Deprecate ``pooling_func`` unused parameter in :class:`cluster.AgglomerativeClustering`. :issue:`9875` by :user:`Kumar Ashutosh `. diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 487545ac039d3..800c85c365988 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -215,8 +215,10 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, # If the distance between two kernels is less than the bandwidth, # then we have to remove one because it is a duplicate. Remove the # one with fewer points. + sorted_by_intensity = sorted(center_intensity_dict.items(), - key=lambda tup: tup[1], reverse=True) + key=lambda tup: (tup[1], tup[0]), + reverse=True) sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) unique = np.ones(len(sorted_centers), dtype=np.bool) nbrs = NearestNeighbors(radius=bandwidth, @@ -359,9 +361,9 @@ class MeanShift(BaseEstimator, ClusterMixin): ... [4, 7], [3, 5], [3, 6]]) >>> clustering = MeanShift(bandwidth=2).fit(X) >>> clustering.labels_ - array([0, 0, 0, 1, 1, 1]) + array([1, 1, 1, 0, 0, 0]) >>> clustering.predict([[0, 0], [5, 5]]) - array([0, 1]) + array([1, 0]) >>> clustering # doctest: +NORMALIZE_WHITESPACE MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1, n_jobs=None, seeds=None) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 1d6940a947dc2..441f822cdbded 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -101,6 +101,18 @@ def test_unfitted(): assert_false(hasattr(ms, "labels_")) +def test_cluster_intensity_tie(): + X = np.array([[1, 1], [2, 1], [1, 0], + [4, 7], [3, 5], [3, 6]]) + c1 = MeanShift(bandwidth=2).fit(X) + + X = np.array([[4, 7], [3, 5], [3, 6], + [1, 1], [2, 1], [1, 0]]) + c2 = MeanShift(bandwidth=2).fit(X) + assert_array_equal(c1.labels_, [1, 1, 1, 0, 0, 0]) + assert_array_equal(c2.labels_, [0, 0, 0, 1, 1, 1]) + + def test_bin_seeds(): # Test the bin seeding technique which can be used in the mean shift # algorithm