Skip to content

Commit

Permalink
MAINT dynamically expose kulsinski and remove support in BallTree (sc…
Browse files Browse the repository at this point in the history
…ikit-learn#25417)

Co-authored-by: Loïc Estève <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
closes scikit-learn#25212
  • Loading branch information
glemaitre committed Jan 26, 2023
1 parent b69abf5 commit 8640ed7
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 33 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ Changelog
when the provided `sample_weight` reduces the problem to a single class in `fit`.
:pr:`24140` by :user:`Jonathan Ohayon <Johayon>` and :user:`Chiara Marmo <cmarmo>`.

:mod:`sklearn.neighbors`
........................

- |Fix| Remove support for `KulsinskiDistance` in :class:`neighbors.BallTree`. This
dissimilarity is not a metric and cannot be supported by the BallTree.
:pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.pipeline`
.......................

Expand Down
6 changes: 6 additions & 0 deletions sklearn/cluster/_optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class OPTICS(ClusterMixin, BaseEstimator):
See the documentation for scipy.spatial.distance for details on these
metrics.
.. note::
`'kulsinski'` is deprecated from SciPy 1.9 and will be removed in SciPy 1.11.
p : float, default=2
Parameter for the Minkowski metric from
:class:`~sklearn.metrics.pairwise_distances`. When p = 1, this is
Expand Down Expand Up @@ -489,6 +492,9 @@ def compute_optics_graph(
See the documentation for scipy.spatial.distance for details on these
metrics.
.. note::
`'kulsinski'` is deprecated from SciPy 1.9 and will be removed in SciPy 1.11.
p : int, default=2
Parameter for the Minkowski metric from
:class:`~sklearn.metrics.pairwise_distances`. When p = 1, this is
Expand Down
5 changes: 4 additions & 1 deletion sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DTYPECODE
from ..utils._typedefs import DTYPE, ITYPE
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
from ..utils import check_array
from ..utils.fixes import parse_version, sp_base_version

cdef inline double fmax(double a, double b) nogil:
return max(a, b)
Expand All @@ -59,12 +60,14 @@ BOOL_METRICS = [
"matching",
"jaccard",
"dice",
"kulsinski",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
BOOL_METRICS += ["kulsinski"]

def get_valid_metric_ids(L):
"""Given an iterable of metric class names or class identifiers,
Expand Down
19 changes: 16 additions & 3 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..preprocessing import normalize
from ..utils._mask import _get_mask
from ..utils.parallel import delayed, Parallel
from ..utils.fixes import sp_version, parse_version
from ..utils.fixes import sp_base_version, sp_version, parse_version
from ..utils._param_validation import validate_params

from ._pairwise_distances_reduction import ArgKmin
Expand Down Expand Up @@ -644,6 +644,9 @@ def pairwise_distances_argmin_min(
See the documentation for scipy.spatial.distance for details on these
metrics.
.. note::
`'kulsinski'` is deprecated from SciPy 1.9 and will be removed in SciPy 1.11.
metric_kwargs : dict, default=None
Keyword arguments to pass to specified metric function.
Expand Down Expand Up @@ -761,6 +764,9 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
See the documentation for scipy.spatial.distance for details on these
metrics.
.. note::
`'kulsinski'` is deprecated from SciPy 1.9 and will be removed in SciPy 1.11.
metric_kwargs : dict, default=None
Keyword arguments to pass to specified metric function.
Expand Down Expand Up @@ -1639,7 +1645,6 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
"dice",
"hamming",
"jaccard",
"kulsinski",
"mahalanobis",
"matching",
"minkowski",
Expand All @@ -1654,6 +1659,9 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
"nan_euclidean",
"haversine",
]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
_VALID_METRICS += ["kulsinski"]

_NAN_METRICS = ["nan_euclidean"]

Expand Down Expand Up @@ -1908,6 +1916,9 @@ def pairwise_distances(
See the documentation for scipy.spatial.distance for details on these
metrics. These metrics do not support sparse matrix inputs.
.. note::
`'kulsinski'` is deprecated from SciPy 1.9 and will be removed in SciPy 1.11.
Note that in the case of 'cityblock', 'cosine' and 'euclidean' (which are
valid scipy.spatial.distance metrics), the scikit-learn implementation
will be used, which is faster and has support for sparse matrices (except
Expand Down Expand Up @@ -2043,14 +2054,16 @@ def pairwise_distances(
PAIRWISE_BOOLEAN_FUNCTIONS = [
"dice",
"jaccard",
"kulsinski",
"matching",
"rogerstanimoto",
"russellrao",
"sokalmichener",
"sokalsneath",
"yule",
]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
PAIRWISE_BOOLEAN_FUNCTIONS += ["kulsinski"]

# Helper functions - distance
PAIRWISE_KERNEL_FUNCTIONS = {
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_pairwise_boolean_distance(metric):
with ignore_warnings(category=DataConversionWarning):
for Z in [Y, None]:
res = pairwise_distances(X, Z, metric=metric)
res[np.isnan(res)] = 0
np.nan_to_num(res, nan=0, posinf=0, neginf=0, copy=False)
assert np.sum(res != 0) == 0

# non-boolean arrays are converted to boolean for boolean
Expand Down
2 changes: 1 addition & 1 deletion sklearn/neighbors/_ball_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ VALID_METRICS = ['EuclideanDistance', 'SEuclideanDistance',
'MahalanobisDistance', 'HammingDistance',
'CanberraDistance', 'BrayCurtisDistance',
'JaccardDistance', 'MatchingDistance',
'DiceDistance', 'KulsinskiDistance',
'DiceDistance',
'RogersTanimotoDistance', 'RussellRaoDistance',
'SokalMichenerDistance', 'SokalSneathDistance',
'PyFuncDistance', 'HaversineDistance']
Expand Down
52 changes: 26 additions & 26 deletions sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,39 +38,39 @@
from ..utils.validation import check_non_negative
from ..utils._param_validation import Interval, StrOptions
from ..utils.parallel import delayed, Parallel
from ..utils.fixes import parse_version, sp_version
from ..utils.fixes import parse_version, sp_base_version, sp_version
from ..exceptions import DataConversionWarning, EfficiencyWarning

SCIPY_METRICS = [
"braycurtis",
"canberra",
"chebyshev",
"correlation",
"cosine",
"dice",
"hamming",
"jaccard",
"mahalanobis",
"matching",
"minkowski",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]
if sp_base_version < parse_version("1.11"):
# Deprecated in SciPy 1.9 and removed in SciPy 1.11
SCIPY_METRICS += ["kulsinski"]

VALID_METRICS = dict(
ball_tree=BallTree.valid_metrics,
kd_tree=KDTree.valid_metrics,
# The following list comes from the
# sklearn.metrics.pairwise doc string
brute=sorted(
set(PAIRWISE_DISTANCE_FUNCTIONS).union(
[
"braycurtis",
"canberra",
"chebyshev",
"correlation",
"cosine",
"dice",
"hamming",
"jaccard",
"kulsinski",
"mahalanobis",
"matching",
"minkowski",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]
)
),
brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)),
)

# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
Expand Down
1 change: 0 additions & 1 deletion sklearn/neighbors/tests/test_ball_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
"matching",
"jaccard",
"dice",
"kulsinski",
"rogerstanimoto",
"russellrao",
"sokalmichener",
Expand Down
1 change: 1 addition & 0 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

np_version = parse_version(np.__version__)
sp_version = parse_version(scipy.__version__)
sp_base_version = parse_version(sp_version.base_version)


if sp_version >= parse_version("1.4"):
Expand Down

0 comments on commit 8640ed7

Please sign in to comment.