Skip to content

Commit

Permalink
Add pairwise_distances_argmin.
Browse files Browse the repository at this point in the history
  • Loading branch information
mblondel committed Aug 30, 2013
1 parent 70695ca commit 3d772e9
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 17 deletions.
8 changes: 4 additions & 4 deletions examples/cluster/plot_color_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import numpy as np
import pylab as pl
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
from sklearn.metrics import pairwise_distances_argmin
from sklearn.datasets import load_sample_image
from sklearn.utils import shuffle
from time import time
Expand Down Expand Up @@ -64,9 +64,9 @@
codebook_random = shuffle(image_array, random_state=0)[:n_colors + 1]
print("Predicting color indices on the full image (random)")
t0 = time()
labels_random = pairwise_distances_argmin_min(codebook_random,
image_array,
axis=0)[0]
labels_random = pairwise_distances_argmin(codebook_random,
image_array,
axis=0)
print("done in %0.3fs." % (time() - t0))


Expand Down
6 changes: 3 additions & 3 deletions examples/cluster/plot_mini_batch_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pylab as pl

from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.metrics.pairwise import pairwise_distances_argmin_min
from sklearn.metrics.pairwise import pairwise_distances_argmin
from sklearn.datasets.samples_generator import make_blobs

##############################################################################
Expand Down Expand Up @@ -66,8 +66,8 @@
# MiniBatchKMeans and the KMeans algorithm. Let's pair the cluster centers per
# closest one.

order = pairwise_distances_argmin_min(k_means_cluster_centers,
mbk_means_cluster_centers)[0]
order = pairwise_distances_argmin(k_means_cluster_centers,
mbk_means_cluster_centers)

# KMeans
ax = fig.add_subplot(1, 3, 1)
Expand Down
4 changes: 2 additions & 2 deletions sklearn/cluster/affinity_propagation_.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..base import BaseEstimator, ClusterMixin
from ..utils import as_float_array
from ..metrics import euclidean_distances
from ..metrics import pairwise_distances_argmin_min
from ..metrics import pairwise_distances_argmin


def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
Expand Down Expand Up @@ -302,4 +302,4 @@ def predict(self, X):
raise ValueError("Predict method is not supported when "
"affinity='precomputed'.")

return pairwise_distances_argmin_min(X, self.cluster_centers_)[0]
return pairwise_distances_argmin(X, self.cluster_centers_)
4 changes: 2 additions & 2 deletions sklearn/cluster/mean_shift_.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..base import BaseEstimator, ClusterMixin
from ..neighbors import NearestNeighbors
from ..metrics.pairwise import euclidean_distances
from ..metrics.pairwise import pairwise_distances_argmin_min
from ..metrics.pairwise import pairwise_distances_argmin


def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0):
Expand Down Expand Up @@ -295,4 +295,4 @@ def predict(self, X):
labels : array, shape [n_samples,]
Index of the cluster each sample belongs to.
"""
return pairwise_distances_argmin_min(X, self.cluster_centers_)[0]
return pairwise_distances_argmin(X, self.cluster_centers_)
1 change: 1 addition & 0 deletions sklearn/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .pairwise import (euclidean_distances,
pairwise_distances,
pairwise_distances_argmin_min,
pairwise_distances_argmin,
pairwise_kernels)

__all__ = ['accuracy_score',
Expand Down
83 changes: 77 additions & 6 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
This is mostly equivalent to calling:
pairwise_distances(X, Y=Y, metric=metric).argmin(axis=axis)
(pairwise_distances(X, Y=Y, metric=metric).argmin(axis=axis),
pairwise_distances(X, Y=Y, metric=metric).min(axis=axis))
but uses much less memory, and is faster for large arrays. It also returns
the minimum values at the same time.
but uses much less memory, and is faster for large arrays.
This function works with dense 2D arrays only.
Expand Down Expand Up @@ -259,9 +259,7 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
See also
========
sklearn.metrics.pairwise_distances
Notes
=====
sklearn.metrics.pairwise_distances_argmin
"""
dist_func = None
if metric in PAIRWISE_DISTANCE_FUNCTIONS:
Expand Down Expand Up @@ -317,6 +315,79 @@ def pairwise_distances_argmin_min(X, Y, axis=1, metric="euclidean",
return indices, values


def pairwise_distances_argmin(X, Y, axis=1, metric="euclidean",
batch_size=500, metric_kwargs={}):
"""Compute minimum distances between one point and a set of points.
This function computes for each row in X, the index of the row of Y which
is closest (according to the specified distance).
This is mostly equivalent to calling:
pairwise_distances(X, Y=Y, metric=metric).argmin(axis=axis)
but uses much less memory, and is faster for large arrays.
This function works with dense 2D arrays only.
Parameters
==========
X, Y : array-like
Arrays containing points. Respective shapes (n_samples1, n_features)
and (n_samples2, n_features)
batch_size : integer
To reduce memory consumption over the naive solution, data are
processed in batches, comprising batch_size rows of X and
batch_size rows of Y. The default value is quite conservative, but
can be changed for fine-tuning. The larger the number, the larger the
memory usage.
metric : string or callable
metric to use for distance computation. Any metric from scikit-learn
or scipy.spatial.distance can be used.
If metric is a callable function, it is called on each
pair of instances (rows) and the resulting value recorded. The callable
should take two arrays as input and return one value indicating the
distance between them. This works for Scipy's metrics, but is less
efficient than passing the metric name as a string.
Distance matrices are not supported.
Valid values for metric are:
- from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
'manhattan']
- from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao',
'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule']
See the documentation for scipy.spatial.distance for details on these
metrics.
metric_kwargs : dict
keyword arguments to pass to specified metric function.
Returns
=======
argmin : numpy.ndarray
Y[argmin[i], :] is the row in Y that is closest to X[i, :].
distances : numpy.ndarray
distances[i] is the distance between the i-th row in X and the
argmin[i]-th row in Y.
See also
========
sklearn.metrics.pairwise_distances
sklearn.metrics.pairwise_distances_argmin_min
"""
return pairwise_distances_argmin_min(X, Y, axis, metric, batch_size,
metric_kwargs)[0]


def manhattan_distances(X, Y=None, sum_over_features=True,
size_threshold=5e8):
""" Compute the L1 distances between the vectors in X and Y.
Expand Down

0 comments on commit 3d772e9

Please sign in to comment.