Skip to content

Commit

Permalink
MAINT fix the way to call stats.mode (scikit-learn#23633)
Browse files Browse the repository at this point in the history
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Meekail Zain <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
  • Loading branch information
4 people committed Aug 5, 2022
1 parent f2e1d9d commit 02a4b34
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions sklearn/impute/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import numpy as np
import numpy.ma as ma
from scipy import sparse as sp
from scipy import stats

from ..base import BaseEstimator, TransformerMixin
from ..utils._param_validation import StrOptions
from ..utils.fixes import _mode
from ..utils.sparsefuncs import _get_median
from ..utils.validation import check_is_fitted
from ..utils.validation import FLOAT_DTYPES
Expand Down Expand Up @@ -52,7 +52,7 @@ def _most_frequent(array, extra_value, n_repeat):
if count == most_frequent_count
)
else:
mode = stats.mode(array)
mode = _mode(array)
most_frequent_value = mode[0][0]
most_frequent_count = mode[1][0]
else:
Expand Down
4 changes: 2 additions & 2 deletions sklearn/neighbors/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numbers import Integral

import numpy as np
from scipy import stats
from ..utils.fixes import _mode
from ..utils.extmath import weighted_mode
from ..utils.validation import _is_arraylike, _num_samples

Expand Down Expand Up @@ -249,7 +249,7 @@ def predict(self, X):
y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype)
for k, classes_k in enumerate(classes_):
if weights is None:
mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
mode, _ = _mode(_y[neigh_ind, k], axis=1)
else:
mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1)

Expand Down
7 changes: 7 additions & 0 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,10 @@ def threadpool_info():


threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__


# TODO: Remove when SciPy 1.9 is the minimum supported version
def _mode(a, axis=0):
if sp_version >= parse_version("1.9.0"):
return scipy.stats.mode(a, axis=axis, keepdims=True)
return scipy.stats.mode(a, axis=axis)
4 changes: 2 additions & 2 deletions sklearn/utils/tests/test_extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
from scipy import sparse
from scipy import linalg
from scipy import stats
from scipy.sparse.linalg import eigsh
from scipy.special import expit

Expand All @@ -19,6 +18,7 @@
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import skip_if_32bit
from sklearn.utils.fixes import _mode

from sklearn.utils.extmath import density, _safe_accumulator_op
from sklearn.utils.extmath import randomized_svd, _randomized_eigsh
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_uniform_weights():
weights = np.ones(x.shape)

for axis in (None, 0, 1):
mode, score = stats.mode(x, axis)
mode, score = _mode(x, axis)
mode2, score2 = weighted_mode(x, weights, axis=axis)

assert_array_equal(mode, mode2)
Expand Down

0 comments on commit 02a4b34

Please sign in to comment.