Skip to content

Commit

Permalink
Fix spectral embedding implementation (scikit-learn#9062)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmargeta authored and lesteve committed Dec 18, 2017
1 parent 0a85111 commit d01cdc2
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 71 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ matrix:
COVERAGE=true
if: type != cron
# This environment tests the newest supported Anaconda release (5.0.0)
# It also runs tests requiring Pandas.
# It also runs tests requiring Pandas and PyAMG.
- env: DISTRIB="conda" PYTHON_VERSION="3.6.2" INSTALL_MKL="true"
NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" PANDAS_VERSION="0.20.3"
CYTHON_VERSION="0.26.1" COVERAGE=true
CYTHON_VERSION="0.26.1" PYAMG_VERSION="3.3.2" COVERAGE=true
CHECK_PYTEST_SOFT_DEPENDENCY="true"
if: type != cron
# flake8 linting on diff wrt common ancestor with upstream/master
Expand Down
27 changes: 16 additions & 11 deletions build_tools/travis/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,25 @@ if [[ "$DISTRIB" == "conda" ]]; then
export PATH=$MINICONDA_PATH/bin:$PATH
conda update --yes conda

# Configure the conda environment and put it in the path using the
# provided versions
TO_INSTALL="python=$PYTHON_VERSION pip pytest pytest-cov \
numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \
cython=$CYTHON_VERSION"

if [[ "$INSTALL_MKL" == "true" ]]; then
conda create -n testenv --yes python=$PYTHON_VERSION pip \
pytest pytest-cov numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \
mkl cython=$CYTHON_VERSION \
${PANDAS_VERSION+pandas=$PANDAS_VERSION}

TO_INSTALL="$TO_INSTALL mkl"
else
conda create -n testenv --yes python=$PYTHON_VERSION pip \
pytest pytest-cov numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \
nomkl cython=$CYTHON_VERSION \
${PANDAS_VERSION+pandas=$PANDAS_VERSION}
TO_INSTALL="$TO_INSTALL nomkl"
fi

if [[ -n "$PANDAS_VERSION" ]]; then
TO_INSTALL="$TO_INSTALL pandas=$PANDAS_VERSION"
fi

if [[ -n "$PYAMG_VERSION" ]]; then
TO_INSTALL="$TO_INSTALL pyamg=$PYAMG_VERSION"
fi

conda create -n testenv --yes $TO_INSTALL
source activate testenv

elif [[ "$DISTRIB" == "ubuntu" ]]; then
Expand Down
11 changes: 8 additions & 3 deletions doc/whats_new/v0.20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,21 @@ Decomposition, manifold learning and clustering
- Fixed a bug when setting parameters on meta-estimator, involving both a
wrapped estimator and its parameter. :issue:`9999` by :user:`Marcus Voss
<marcus-voss>` and `Joel Nothman`_.

- ``k_means`` now gives a warning, if the number of distinct clusters found
is smaller than ``n_clusters``. This may occur when the number of distinct
points in the data set is actually smaller than the number of clusters one is
is smaller than ``n_clusters``. This may occur when the number of distinct
points in the data set is actually smaller than the number of clusters one is
looking for. :issue:`10059` by :user:`Christian Braune <christianbraune79>`.

- Fixed a bug in :func:`datasets.make_circles`, where an odd number of data
points could not be generated. :issue:`10037` by :user:`Christian Braune
<christianbraune79>`.
- Fixed a bug in :func:`manifold.spectral_embedding` (also affecting
:func:`cluster.spectral_clustering`) where the normalization of the spectrum
was using a multiplication instead of a division. :issue:`8129`
by :user:`Jan Margeta <jmargeta>`, :user:`Guillaume Lemaitre <glemaitre>`,
and :user:`Devansh D. <devanshdalal>`.

Metrics

- Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with
Expand Down
2 changes: 1 addition & 1 deletion examples/cluster/plot_face_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
for assign_labels in ('kmeans', 'discretize'):
t0 = time.time()
labels = spectral_clustering(graph, n_clusters=N_REGIONS,
assign_labels=assign_labels, random_state=1)
assign_labels=assign_labels, random_state=42)
t1 = time.time()
labels = labels.reshape(face.shape)

Expand Down
4 changes: 4 additions & 0 deletions sklearn/cluster/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None,

random_state = check_random_state(random_state)
n_components = n_clusters if n_components is None else n_components

# The first eigenvector is constant only for fully connected graphs
# and should be kept for spectral clustering (drop_first=False).
# See spectral_embedding documentation.
maps = spectral_embedding(affinity, n_components=n_components,
eigen_solver=eigen_solver,
random_state=random_state,
Expand Down
103 changes: 56 additions & 47 deletions sklearn/cluster/tests/test_spectral.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
"""Testing for Spectral Clustering methods"""

from sklearn.externals.six.moves import cPickle

dumps, loads = cPickle.dumps, cPickle.loads
from __future__ import division

import numpy as np
from scipy import sparse

from sklearn.externals.six.moves import cPickle

from sklearn.utils import check_random_state
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_warns_message

from sklearn.cluster import SpectralClustering, spectral_clustering
from sklearn.cluster.spectral import spectral_embedding
from sklearn.cluster.spectral import discretize
from sklearn.feature_extraction import img_to_graph
from sklearn.metrics import pairwise_distances
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics.pairwise import kernel_metrics, rbf_kernel
from sklearn.datasets.samples_generator import make_blobs

try:
from pyamg import smoothed_aggregation_solver # noqa
amg_loaded = True
except ImportError:
amg_loaded = False


def test_spectral_clustering():
S = np.array([[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0],
Expand All @@ -44,44 +48,14 @@ def test_spectral_clustering():
if labels[0] == 0:
labels = 1 - labels

assert_array_equal(labels, [1, 1, 1, 0, 0, 0, 0])
assert adjusted_rand_score(labels, [1, 1, 1, 0, 0, 0, 0]) == 1

model_copy = loads(dumps(model))
assert_equal(model_copy.n_clusters, model.n_clusters)
assert_equal(model_copy.eigen_solver, model.eigen_solver)
model_copy = cPickle.loads(cPickle.dumps(model))
assert model_copy.n_clusters == model.n_clusters
assert model_copy.eigen_solver == model.eigen_solver
assert_array_equal(model_copy.labels_, model.labels_)


def test_spectral_amg_mode():
    # Exercise the "amg" eigen solver: it must produce a usable clustering
    # when pyamg is importable, and raise ValueError when it is not.
    blob_centers = np.array([
        [0., 0., 0.],
        [10., 10., 10.],
        [20., 20., 20.],
    ])
    n_centers = len(blob_centers)
    X, true_labels = make_blobs(n_samples=100, centers=blob_centers,
                                cluster_std=1., random_state=42)

    # Convert pairwise distances into a sparse similarity matrix by
    # subtracting every distance from the largest one.
    distances = pairwise_distances(X)
    S = sparse.coo_matrix(np.max(distances) - distances)

    try:
        from pyamg import smoothed_aggregation_solver  # noqa
    except ImportError:
        have_amg = False
    else:
        have_amg = True

    if not have_amg:
        # Without pyamg the amg solver cannot be used at all.
        assert_raises(ValueError, spectral_embedding, S,
                      n_components=n_centers,
                      random_state=0, eigen_solver="amg")
        return

    labels = spectral_clustering(S, n_clusters=n_centers,
                                 random_state=0, eigen_solver="amg")
    # We don't care too much that it's good, just that it *worked*.
    # There does have to be some lower limit on the performance though.
    assert_greater(np.mean(labels == true_labels), .3)


def test_spectral_unknown_mode():
# Test that SpectralClustering fails with an unknown mode set.
centers = np.array([
Expand Down Expand Up @@ -124,7 +98,7 @@ def test_spectral_clustering_sparse():

labels = SpectralClustering(random_state=0, n_clusters=2,
affinity='precomputed').fit(S).labels_
assert_equal(adjusted_rand_score(y, labels), 1)
assert adjusted_rand_score(y, labels) == 1


def test_affinities():
Expand All @@ -138,11 +112,11 @@ def test_affinities():
sp = SpectralClustering(n_clusters=2, affinity='nearest_neighbors',
random_state=0)
assert_warns_message(UserWarning, 'not fully connected', sp.fit, X)
assert_equal(adjusted_rand_score(y, sp.labels_), 1)
assert adjusted_rand_score(y, sp.labels_) == 1

sp = SpectralClustering(n_clusters=2, gamma=2, random_state=0)
labels = sp.fit(X).labels_
assert_equal(adjusted_rand_score(y, labels), 1)
assert adjusted_rand_score(y, labels) == 1

X = check_random_state(10).rand(10, 5) * 10

Expand All @@ -154,12 +128,12 @@ def test_affinities():
sp = SpectralClustering(n_clusters=2, affinity=kern,
random_state=0)
labels = sp.fit(X).labels_
assert_equal((X.shape[0],), labels.shape)
assert (X.shape[0],) == labels.shape

sp = SpectralClustering(n_clusters=2, affinity=lambda x, y: 1,
random_state=0)
labels = sp.fit(X).labels_
assert_equal((X.shape[0],), labels.shape)
assert (X.shape[0],) == labels.shape

def histogram(x, y, **kwargs):
# Histogram kernel implemented as a callable.
Expand All @@ -168,7 +142,7 @@ def histogram(x, y, **kwargs):

sp = SpectralClustering(n_clusters=2, affinity=histogram, random_state=0)
labels = sp.fit(X).labels_
assert_equal((X.shape[0],), labels.shape)
assert (X.shape[0],) == labels.shape

# raise error on unknown affinity
sp = SpectralClustering(n_clusters=2, affinity='<unknown>')
Expand All @@ -193,4 +167,39 @@ def test_discretize(seed=8):
+ 0.1 * random_state.randn(n_samples,
n_class + 1))
y_pred = discretize(y_true_noisy, random_state)
assert_greater(adjusted_rand_score(y_true, y_pred), 0.8)
assert adjusted_rand_score(y_true, y_pred) > 0.8


def test_spectral_clustering_with_arpack_amg_solvers():
    # Check that spectral_clustering gives the same partition with the
    # arpack and the amg eigen solvers.
    # Based on the toy example from plot_segmentation_toy.py

    # Build a small 40x40 image made of two coins, each described by
    # ((center_row, center_col), radius).
    rows, cols = np.indices((40, 40))
    coins = (((14, 12), 8), ((20, 25), 7))

    foreground = np.zeros((40, 40), dtype=bool)
    for (row0, col0), radius in coins:
        foreground |= (rows - row0) ** 2 + (cols - col0) ** 2 < radius ** 2

    pixel_mask = foreground.copy()
    image = foreground.astype(float)

    graph = img_to_graph(image, mask=pixel_mask)
    # Rescale the edge weights with a decaying exponential, using their
    # standard deviation as the scale.
    graph.data = np.exp(-graph.data / graph.data.std())

    shared_kwargs = dict(n_clusters=2, random_state=0)

    labels_arpack = spectral_clustering(graph, eigen_solver='arpack',
                                        **shared_kwargs)
    assert len(np.unique(labels_arpack)) == 2

    if not amg_loaded:
        # pyamg is missing: requesting the amg solver must fail loudly.
        assert_raises(ValueError, spectral_clustering, graph,
                      eigen_solver='amg', **shared_kwargs)
        return

    labels_amg = spectral_clustering(graph, eigen_solver='amg',
                                     **shared_kwargs)
    assert adjusted_rand_score(labels_arpack, labels_amg) == 1
18 changes: 14 additions & 4 deletions sklearn/manifold/spectral_embedding_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Wei LI <[email protected]>
# License: BSD 3 clause

from __future__ import division

import warnings

import numpy as np
Expand Down Expand Up @@ -269,7 +271,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
lambdas, diffusion_map = eigsh(laplacian, k=n_components,
sigma=1.0, which='LM',
tol=eigen_tol, v0=v0)
embedding = diffusion_map.T[n_components::-1] * dd
embedding = diffusion_map.T[n_components::-1]
if norm_laplacian:
embedding = embedding / dd
except RuntimeError:
# When submatrices are exactly singular, an LU decomposition
# in arpack fails. We fallback to lobpcg
Expand All @@ -292,7 +296,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
X[:, 0] = dd.ravel()
lambdas, diffusion_map = lobpcg(laplacian, X, M=M, tol=1.e-12,
largest=False)
embedding = diffusion_map.T * dd
embedding = diffusion_map.T
if norm_laplacian:
embedding = embedding / dd
if embedding.shape[0] == 1:
raise ValueError

Expand All @@ -307,7 +313,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
if sparse.isspmatrix(laplacian):
laplacian = laplacian.toarray()
lambdas, diffusion_map = eigh(laplacian)
embedding = diffusion_map.T[:n_components] * dd
embedding = diffusion_map.T[:n_components]
if norm_laplacian:
embedding = embedding / dd
else:
laplacian = _set_diag(laplacian, 1, norm_laplacian)
# We increase the number of eigenvectors requested, as lobpcg
Expand All @@ -316,7 +324,9 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
X[:, 0] = dd.ravel()
lambdas, diffusion_map = lobpcg(laplacian, X, tol=1e-15,
largest=False, maxiter=2000)
embedding = diffusion_map.T[:n_components] * dd
embedding = diffusion_map.T[:n_components]
if norm_laplacian:
embedding = embedding / dd
if embedding.shape[0] == 1:
raise ValueError

Expand Down
27 changes: 24 additions & 3 deletions sklearn/manifold/tests/test_spectral_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

import numpy as np
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_equal

from scipy import sparse
from scipy.sparse import csgraph
Expand All @@ -15,6 +15,8 @@
from sklearn.cluster import KMeans
from sklearn.datasets.samples_generator import make_blobs
from sklearn.utils.extmath import _deterministic_vector_sign_flip
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_true, assert_equal, assert_raises
from sklearn.utils.testing import SkipTest

Expand Down Expand Up @@ -258,7 +260,26 @@ def test_spectral_embedding_unnormalized():
laplacian, dd = csgraph.laplacian(sims, normed=False,
return_diag=True)
_, diffusion_map = eigh(laplacian)
embedding_2 = diffusion_map.T[:n_components] * dd
embedding_2 = diffusion_map.T[:n_components]
embedding_2 = _deterministic_vector_sign_flip(embedding_2).T

assert_array_almost_equal(embedding_1, embedding_2)


def test_spectral_embedding_first_eigen_vector():
    # For a connected graph the first eigenvector returned by
    # spectral_embedding must be constant while the second must not be.
    rng = np.random.RandomState(36)
    affinity = rbf_kernel(rng.randn(10, 30))

    # Repeat with several seeds so the check does not hinge on one
    # particular random initialisation.
    for seed in range(10):
        embedding = spectral_embedding(affinity,
                                       norm_laplacian=False,
                                       n_components=2,
                                       drop_first=False,
                                       random_state=seed)
        first_component = embedding[:, 0]
        second_component = embedding[:, 1]

        assert np.std(first_component) == pytest.approx(0)
        assert np.std(second_component) > 1e-3

0 comments on commit d01cdc2

Please sign in to comment.