From d01cdc204ee4972307aca4cc1e1b1e5e6347cc70 Mon Sep 17 00:00:00 2001 From: Jan Margeta Date: Mon, 18 Dec 2017 10:20:53 +0100 Subject: [PATCH] Fix spectral embedding implementation (#9062) --- .travis.yml | 4 +- build_tools/travis/install.sh | 27 +++-- doc/whats_new/v0.20.rst | 11 +- examples/cluster/plot_face_segmentation.py | 2 +- sklearn/cluster/spectral.py | 4 + sklearn/cluster/tests/test_spectral.py | 103 ++++++++++-------- sklearn/manifold/spectral_embedding_.py | 18 ++- .../manifold/tests/test_spectral_embedding.py | 27 ++++- 8 files changed, 125 insertions(+), 71 deletions(-) diff --git a/.travis.yml b/.travis.yml index aea19fc6b36b4..4022f78aa0928 100644 --- a/.travis.yml +++ b/.travis.yml @@ -39,10 +39,10 @@ matrix: COVERAGE=true if: type != cron # This environment tests the newest supported Anaconda release (5.0.0) - # It also runs tests requiring Pandas. + # It also runs tests requiring Pandas and PyAMG - env: DISTRIB="conda" PYTHON_VERSION="3.6.2" INSTALL_MKL="true" NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" PANDAS_VERSION="0.20.3" - CYTHON_VERSION="0.26.1" COVERAGE=true + CYTHON_VERSION="0.26.1" PYAMG_VERSION="3.3.2" COVERAGE=true CHECK_PYTEST_SOFT_DEPENDENCY="true" if: type != cron # flake8 linting on diff wrt common ancestor with upstream/master diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index 1acfe10280e86..76cd3221cb009 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -37,20 +37,25 @@ if [[ "$DISTRIB" == "conda" ]]; then export PATH=$MINICONDA_PATH/bin:$PATH conda update --yes conda - # Configure the conda environment and put it in the path using the - # provided versions + TO_INSTALL="python=$PYTHON_VERSION pip pytest pytest-cov \ + numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ + cython=$CYTHON_VERSION" + if [[ "$INSTALL_MKL" == "true" ]]; then - conda create -n testenv --yes python=$PYTHON_VERSION pip \ - pytest pytest-cov numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ - mkl cython=$CYTHON_VERSION \ - ${PANDAS_VERSION+pandas=$PANDAS_VERSION} - + TO_INSTALL="$TO_INSTALL mkl" else - conda create -n testenv --yes python=$PYTHON_VERSION pip \ - pytest pytest-cov numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ - nomkl cython=$CYTHON_VERSION \ - ${PANDAS_VERSION+pandas=$PANDAS_VERSION} + TO_INSTALL="$TO_INSTALL nomkl" + fi + + if [[ -n "$PANDAS_VERSION" ]]; then + TO_INSTALL="$TO_INSTALL pandas=$PANDAS_VERSION" fi + + if [[ -n "$PYAMG_VERSION" ]]; then + TO_INSTALL="$TO_INSTALL pyamg=$PYAMG_VERSION" + fi + + conda create -n testenv --yes $TO_INSTALL source activate testenv elif [[ "$DISTRIB" == "ubuntu" ]]; then diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index c6d8645160341..6ff4e7b059a70 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -225,16 +225,21 @@ Decomposition, manifold learning and clustering - Fixed a bug when setting parameters on meta-estimator, involving both a wrapped estimator and its parameter. :issue:`9999` by :user:`Marcus Voss ` and `Joel Nothman`_. - + - ``k_means`` now gives a warning, if the number of distinct clusters found - is smaller than ``n_clusters``. This may occur when the number of distinct - points in the data set is actually smaller than the number of cluster one is + is smaller than ``n_clusters``. This may occur when the number of distinct + points in the data set is actually smaller than the number of cluster one is looking for. :issue:`10059` by :user:`Christian Braune `. - Fixed a bug in :func:`datasets.make_circles`, where no odd number of data points could be generated. :issue:`10037` by :user:`Christian Braune `_. +- Fixed a bug in :func:`cluster.spectral_clustering` where the normalization of + the spectrum was using a division instead of a multiplication. :issue:`8129` + by :user:`Jan Margeta `, :user:`Guillaume Lemaitre `, + and :user:`Devansh D. `. + Metrics - Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with diff --git a/examples/cluster/plot_face_segmentation.py b/examples/cluster/plot_face_segmentation.py index 12b7318b7e338..c67e61d5d37e2 100644 --- a/examples/cluster/plot_face_segmentation.py +++ b/examples/cluster/plot_face_segmentation.py @@ -63,7 +63,7 @@ for assign_labels in ('kmeans', 'discretize'): t0 = time.time() labels = spectral_clustering(graph, n_clusters=N_REGIONS, - assign_labels=assign_labels, random_state=1) + assign_labels=assign_labels, random_state=42) t1 = time.time() labels = labels.reshape(face.shape) diff --git a/sklearn/cluster/spectral.py b/sklearn/cluster/spectral.py index f224098285d44..5051043c31477 100644 --- a/sklearn/cluster/spectral.py +++ b/sklearn/cluster/spectral.py @@ -256,6 +256,10 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None, random_state = check_random_state(random_state) n_components = n_clusters if n_components is None else n_components + + # The first eigen vector is constant only for fully connected graphs + # and should be kept for spectral clustering (drop_first = False) + # See spectral_embedding documentation. maps = spectral_embedding(affinity, n_components=n_components, eigen_solver=eigen_solver, random_state=random_state, diff --git a/sklearn/cluster/tests/test_spectral.py b/sklearn/cluster/tests/test_spectral.py index 48b1a8f32ea38..62d9adcc2e34f 100644 --- a/sklearn/cluster/tests/test_spectral.py +++ b/sklearn/cluster/tests/test_spectral.py @@ -1,27 +1,31 @@ """Testing for Spectral Clustering methods""" - -from sklearn.externals.six.moves import cPickle - -dumps, loads = cPickle.dumps, cPickle.loads +from __future__ import division import numpy as np from scipy import sparse +from sklearn.externals.six.moves import cPickle + from sklearn.utils import check_random_state from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_raises -from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_warns_message from sklearn.cluster import SpectralClustering, spectral_clustering -from sklearn.cluster.spectral import spectral_embedding from sklearn.cluster.spectral import discretize +from sklearn.feature_extraction import img_to_graph from sklearn.metrics import pairwise_distances from sklearn.metrics import adjusted_rand_score from sklearn.metrics.pairwise import kernel_metrics, rbf_kernel from sklearn.datasets.samples_generator import make_blobs +try: + from pyamg import smoothed_aggregation_solver # noqa + amg_loaded = True +except ImportError: + amg_loaded = False + def test_spectral_clustering(): S = np.array([[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0], @@ -44,44 +48,14 @@ def test_spectral_clustering(): if labels[0] == 0: labels = 1 - labels - assert_array_equal(labels, [1, 1, 1, 0, 0, 0, 0]) + assert adjusted_rand_score(labels, [1, 1, 1, 0, 0, 0, 0]) == 1 - model_copy = loads(dumps(model)) - assert_equal(model_copy.n_clusters, model.n_clusters) - assert_equal(model_copy.eigen_solver, model.eigen_solver) + model_copy = cPickle.loads(cPickle.dumps(model)) + assert model_copy.n_clusters == model.n_clusters + assert model_copy.eigen_solver == model.eigen_solver assert_array_equal(model_copy.labels_, model.labels_) -def test_spectral_amg_mode(): - # Test the amg mode of SpectralClustering - centers = np.array([ - [0., 0., 0.], - [10., 10., 10.], - [20., 20., 20.], - ]) - X, true_labels = make_blobs(n_samples=100, centers=centers, - cluster_std=1., random_state=42) - D = pairwise_distances(X) # Distance matrix - S = np.max(D) - D # Similarity matrix - S = sparse.coo_matrix(S) - try: - from pyamg import smoothed_aggregation_solver # noqa - - amg_loaded = True - except ImportError: - amg_loaded = False - if amg_loaded: - labels = spectral_clustering(S, n_clusters=len(centers), - random_state=0, eigen_solver="amg") - # We don't care too much that it's good, just that it *worked*. - # There does have to be some lower limit on the performance though. - assert_greater(np.mean(labels == true_labels), .3) - else: - assert_raises(ValueError, spectral_embedding, S, - n_components=len(centers), - random_state=0, eigen_solver="amg") - - def test_spectral_unknown_mode(): # Test that SpectralClustering fails with an unknown mode set. centers = np.array([ @@ -124,7 +98,7 @@ def test_spectral_clustering_sparse(): labels = SpectralClustering(random_state=0, n_clusters=2, affinity='precomputed').fit(S).labels_ - assert_equal(adjusted_rand_score(y, labels), 1) + assert adjusted_rand_score(y, labels) == 1 def test_affinities(): @@ -138,11 +112,11 @@ def test_affinities(): sp = SpectralClustering(n_clusters=2, affinity='nearest_neighbors', random_state=0) assert_warns_message(UserWarning, 'not fully connected', sp.fit, X) - assert_equal(adjusted_rand_score(y, sp.labels_), 1) + assert adjusted_rand_score(y, sp.labels_) == 1 sp = SpectralClustering(n_clusters=2, gamma=2, random_state=0) labels = sp.fit(X).labels_ - assert_equal(adjusted_rand_score(y, labels), 1) + assert adjusted_rand_score(y, labels) == 1 X = check_random_state(10).rand(10, 5) * 10 @@ -154,12 +128,12 @@ def test_affinities(): sp = SpectralClustering(n_clusters=2, affinity=kern, random_state=0) labels = sp.fit(X).labels_ - assert_equal((X.shape[0],), labels.shape) + assert (X.shape[0],) == labels.shape sp = SpectralClustering(n_clusters=2, affinity=lambda x, y: 1, random_state=0) labels = sp.fit(X).labels_ - assert_equal((X.shape[0],), labels.shape) + assert (X.shape[0],) == labels.shape def histogram(x, y, **kwargs): # Histogram kernel implemented as a callable. @@ -168,7 +142,7 @@ def histogram(x, y, **kwargs): sp = SpectralClustering(n_clusters=2, affinity=histogram, random_state=0) labels = sp.fit(X).labels_ - assert_equal((X.shape[0],), labels.shape) + assert (X.shape[0],) == labels.shape # raise error on unknown affinity sp = SpectralClustering(n_clusters=2, affinity='') @@ -193,4 +167,39 @@ def test_discretize(seed=8): + 0.1 * random_state.randn(n_samples, n_class + 1)) y_pred = discretize(y_true_noisy, random_state) - assert_greater(adjusted_rand_score(y_true, y_pred), 0.8) + assert adjusted_rand_score(y_true, y_pred) > 0.8 + + +def test_spectral_clustering_with_arpack_amg_solvers(): + # Test that spectral_clustering is the same for arpack and amg solver + # Based on toy example from plot_segmentation_toy.py + + # a small two coin image + x, y = np.indices((40, 40)) + + center1, center2 = (14, 12), (20, 25) + radius1, radius2 = 8, 7 + + circle1 = (x - center1[0]) ** 2 + (y - center1[1]) ** 2 < radius1 ** 2 + circle2 = (x - center2[0]) ** 2 + (y - center2[1]) ** 2 < radius2 ** 2 + + circles = circle1 | circle2 + mask = circles.copy() + img = circles.astype(float) + + graph = img_to_graph(img, mask=mask) + graph.data = np.exp(-graph.data / graph.data.std()) + + labels_arpack = spectral_clustering( + graph, n_clusters=2, eigen_solver='arpack', random_state=0) + + assert len(np.unique(labels_arpack)) == 2 + + if amg_loaded: + labels_amg = spectral_clustering( + graph, n_clusters=2, eigen_solver='amg', random_state=0) + assert adjusted_rand_score(labels_arpack, labels_amg) == 1 + else: + assert_raises( + ValueError, spectral_clustering, + graph, n_clusters=2, eigen_solver='amg', random_state=0) diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index c1d0e2e5a75f2..e399c75708fcb 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -4,6 +4,8 @@ # Wei LI # License: BSD 3 clause +from __future__ import division + import warnings import numpy as np @@ -269,7 +271,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, lambdas, diffusion_map = eigsh(laplacian, k=n_components, sigma=1.0, which='LM', tol=eigen_tol, v0=v0) - embedding = diffusion_map.T[n_components::-1] * dd + embedding = diffusion_map.T[n_components::-1] + if norm_laplacian: + embedding = embedding / dd except RuntimeError: # When submatrices are exactly singular, an LU decomposition # in arpack fails. We fallback to lobpcg @@ -292,7 +296,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, X[:, 0] = dd.ravel() lambdas, diffusion_map = lobpcg(laplacian, X, M=M, tol=1.e-12, largest=False) - embedding = diffusion_map.T * dd + embedding = diffusion_map.T + if norm_laplacian: + embedding = embedding / dd if embedding.shape[0] == 1: raise ValueError @@ -307,7 +313,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, if sparse.isspmatrix(laplacian): laplacian = laplacian.toarray() lambdas, diffusion_map = eigh(laplacian) - embedding = diffusion_map.T[:n_components] * dd + embedding = diffusion_map.T[:n_components] + if norm_laplacian: + embedding = embedding / dd else: laplacian = _set_diag(laplacian, 1, norm_laplacian) # We increase the number of eigenvectors requested, as lobpcg @@ -316,7 +324,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, X[:, 0] = dd.ravel() lambdas, diffusion_map = lobpcg(laplacian, X, tol=1e-15, largest=False, maxiter=2000) - embedding = diffusion_map.T[:n_components] * dd + embedding = diffusion_map.T[:n_components] + if norm_laplacian: + embedding = embedding / dd if embedding.shape[0] == 1: raise ValueError diff --git a/sklearn/manifold/tests/test_spectral_embedding.py b/sklearn/manifold/tests/test_spectral_embedding.py index dd746f2af2597..bc32b58c6e7ff 100644 --- a/sklearn/manifold/tests/test_spectral_embedding.py +++ b/sklearn/manifold/tests/test_spectral_embedding.py @@ -1,6 +1,6 @@ +import pytest + import numpy as np -from numpy.testing import assert_array_almost_equal -from numpy.testing import assert_array_equal from scipy import sparse from scipy.sparse import csgraph @@ -15,6 +15,8 @@ from sklearn.cluster import KMeans from sklearn.datasets.samples_generator import make_blobs from sklearn.utils.extmath import _deterministic_vector_sign_flip +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_true, assert_equal, assert_raises from sklearn.utils.testing import SkipTest @@ -258,7 +260,26 @@ def test_spectral_embedding_unnormalized(): laplacian, dd = csgraph.laplacian(sims, normed=False, return_diag=True) _, diffusion_map = eigh(laplacian) - embedding_2 = diffusion_map.T[:n_components] * dd + embedding_2 = diffusion_map.T[:n_components] embedding_2 = _deterministic_vector_sign_flip(embedding_2).T assert_array_almost_equal(embedding_1, embedding_2) + + +def test_spectral_embedding_first_eigen_vector(): + # Test that the first eigenvector of spectral_embedding + # is constant and that the second is not (for a connected graph) + random_state = np.random.RandomState(36) + data = random_state.randn(10, 30) + sims = rbf_kernel(data) + n_components = 2 + + for seed in range(10): + embedding = spectral_embedding(sims, + norm_laplacian=False, + n_components=n_components, + drop_first=False, + random_state=seed) + + assert np.std(embedding[:, 0]) == pytest.approx(0) + assert np.std(embedding[:, 1]) > 1e-3