Skip to content

Commit

Permalink
MAINT Parameters validation for sklearn.manifold.trustworthiness (sci…
Browse files Browse the repository at this point in the history
…kit-learn#26276)

Co-authored-by: jeremie du boisberranger <[email protected]>
  • Loading branch information
ROMEEZHOU and jeremiedbb committed Jun 30, 2023
1 parent 6f2cf7c commit 9cabb12
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
15 changes: 12 additions & 3 deletions sklearn/manifold/_t_sne.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from ..neighbors import NearestNeighbors
from ..utils import check_random_state
from ..utils._openmp_helpers import _openmp_effective_n_threads
from ..utils._param_validation import Interval, StrOptions
from ..utils.validation import check_non_negative
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.validation import _num_samples, check_non_negative

# mypy error: Module 'sklearn.manifold' has no attribute '_utils'
# mypy error: Module 'sklearn.manifold' has no attribute '_barnes_hut_tsne'
Expand Down Expand Up @@ -446,6 +446,15 @@ def _gradient_descent(
return p, error, i


@validate_params(
{
"X": ["array-like", "sparse matrix"],
"X_embedded": ["array-like", "sparse matrix"],
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
},
prefer_skip_nested_validation=True,
)
def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
r"""Indicate to what extent the local structure is retained.
Expand Down Expand Up @@ -504,7 +513,7 @@ def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
Local Structure. Proceedings of the Twelfth International Conference on
Artificial Intelligence and Statistics, PMLR 5:384-391, 2009.
"""
n_samples = X.shape[0]
n_samples = _num_samples(X)
if n_neighbors >= n_samples / 2:
raise ValueError(
f"n_neighbors ({n_neighbors}) should be less than n_samples / 2"
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def _check_function_param_validation(
"sklearn.linear_model.orthogonal_mp",
"sklearn.linear_model.orthogonal_mp_gram",
"sklearn.linear_model.ridge_regression",
"sklearn.manifold.trustworthiness",
"sklearn.metrics.accuracy_score",
"sklearn.manifold.smacof",
"sklearn.metrics.auc",
Expand Down

0 comments on commit 9cabb12

Please sign in to comment.