Skip to content

Commit

Permalink
MAINT Parameters validation for NeighborhoodComponentsAnalysis (sciki…
Browse files Browse the repository at this point in the history
…t-learn#24195)

Co-authored-by: jeremie du boisberranger <[email protected]>
  • Loading branch information
EdAbati and jeremiedbb committed Sep 1, 2022
1 parent 4114161 commit d7c978b
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 145 deletions.
193 changes: 72 additions & 121 deletions sklearn/neighbors/_nca.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# License: BSD 3 clause

from warnings import warn
from numbers import Integral, Real
import numpy as np
import sys
import time
import numbers
from scipy.optimize import minimize
from ..utils.extmath import softmax
from ..metrics import pairwise_distances
Expand All @@ -19,7 +19,8 @@
from ..decomposition import PCA
from ..utils.multiclass import check_classification_targets
from ..utils.random import check_random_state
from ..utils.validation import check_is_fitted, check_array, check_scalar
from ..utils.validation import check_is_fitted, check_array
from ..utils._param_validation import Interval, StrOptions
from ..exceptions import ConvergenceWarning


Expand Down Expand Up @@ -176,6 +177,23 @@ class NeighborhoodComponentsAnalysis(
0.961904...
"""

_parameter_constraints: dict = {
"n_components": [
Interval(Integral, 1, None, closed="left"),
None,
],
"init": [
StrOptions({"auto", "pca", "lda", "identity", "random"}),
np.ndarray,
],
"warm_start": ["boolean"],
"max_iter": [Interval(Integral, 1, None, closed="left")],
"tol": [Interval(Real, 0, None, closed="left")],
"callback": [callable, None],
"verbose": ["verbose"],
"random_state": ["random_state"],
}

def __init__(
self,
n_components=None,
Expand Down Expand Up @@ -213,10 +231,59 @@ def fit(self, X, y):
self : object
Fitted estimator.
"""
self._validate_params()

# Verify inputs X and y and NCA parameters, and transform a copy if
# needed
X, y, init = self._validate_params(X, y)
# Validate the inputs X and y, and converts y to numerical classes.
X, y = self._validate_data(X, y, ensure_min_samples=2)
check_classification_targets(y)
y = LabelEncoder().fit_transform(y)

# Check the preferred dimensionality of the projected space
if self.n_components is not None and self.n_components > X.shape[1]:
raise ValueError(
"The preferred dimensionality of the "
f"projected space `n_components` ({self.n_components}) cannot "
"be greater than the given data "
f"dimensionality ({X.shape[1]})!"
)
# If warm_start is enabled, check that the inputs are consistent
if (
self.warm_start
and hasattr(self, "components_")
and self.components_.shape[1] != X.shape[1]
):
raise ValueError(
f"The new inputs dimensionality ({X.shape[1]}) does not "
"match the input dimensionality of the "
f"previously learned transformation ({self.components_.shape[1]})."
)
# Check how the linear transformation should be initialized
init = self.init
if isinstance(init, np.ndarray):
init = check_array(init)
# Assert that init.shape[1] = X.shape[1]
if init.shape[1] != X.shape[1]:
raise ValueError(
f"The input dimensionality ({init.shape[1]}) of the given "
"linear transformation `init` must match the "
f"dimensionality of the given inputs `X` ({X.shape[1]})."
)
# Assert that init.shape[0] <= init.shape[1]
if init.shape[0] > init.shape[1]:
raise ValueError(
f"The output dimensionality ({init.shape[0]}) of the given "
"linear transformation `init` cannot be "
f"greater than its input dimensionality ({init.shape[1]})."
)
# Assert that self.n_components = init.shape[0]
if self.n_components is not None and self.n_components != init.shape[0]:
raise ValueError(
"The preferred dimensionality of the "
f"projected space `n_components` ({self.n_components}) does"
" not match the output dimensionality of "
"the given linear transformation "
f"`init` ({init.shape[0]})!"
)

# Initialize the random generator
self.random_state_ = check_random_state(self.random_state)
Expand Down Expand Up @@ -294,122 +361,6 @@ def transform(self, X):

return np.dot(X, self.components_.T)

def _validate_params(self, X, y):
"""Validate parameters as soon as :meth:`fit` is called.
Parameters
----------
X : array-like of shape (n_samples, n_features)
The training samples.
y : array-like of shape (n_samples,)
The corresponding training labels.
Returns
-------
X : ndarray of shape (n_samples, n_features)
The validated training samples.
y : ndarray of shape (n_samples,)
The validated training labels, encoded to be integers in
the `range(0, n_classes)`.
init : str or ndarray of shape (n_features_a, n_features_b)
The validated initialization of the linear transformation.
Raises
-------
TypeError
If a parameter is not an instance of the desired type.
ValueError
If a parameter's value violates its legal value range or if the
combination of two or more given parameters is incompatible.
"""

# Validate the inputs X and y, and converts y to numerical classes.
X, y = self._validate_data(X, y, ensure_min_samples=2)
check_classification_targets(y)
y = LabelEncoder().fit_transform(y)

# Check the preferred dimensionality of the projected space
if self.n_components is not None:
check_scalar(self.n_components, "n_components", numbers.Integral, min_val=1)

if self.n_components > X.shape[1]:
raise ValueError(
"The preferred dimensionality of the "
"projected space `n_components` ({}) cannot "
"be greater than the given data "
"dimensionality ({})!".format(self.n_components, X.shape[1])
)

# If warm_start is enabled, check that the inputs are consistent
check_scalar(self.warm_start, "warm_start", bool)
if self.warm_start and hasattr(self, "components_"):
if self.components_.shape[1] != X.shape[1]:
raise ValueError(
"The new inputs dimensionality ({}) does not "
"match the input dimensionality of the "
"previously learned transformation ({}).".format(
X.shape[1], self.components_.shape[1]
)
)

check_scalar(self.max_iter, "max_iter", numbers.Integral, min_val=1)
check_scalar(self.tol, "tol", numbers.Real, min_val=0.0)
check_scalar(self.verbose, "verbose", numbers.Integral, min_val=0)

if self.callback is not None:
if not callable(self.callback):
raise ValueError("`callback` is not callable.")

# Check how the linear transformation should be initialized
init = self.init

if isinstance(init, np.ndarray):
init = check_array(init)

# Assert that init.shape[1] = X.shape[1]
if init.shape[1] != X.shape[1]:
raise ValueError(
"The input dimensionality ({}) of the given "
"linear transformation `init` must match the "
"dimensionality of the given inputs `X` ({}).".format(
init.shape[1], X.shape[1]
)
)

# Assert that init.shape[0] <= init.shape[1]
if init.shape[0] > init.shape[1]:
raise ValueError(
"The output dimensionality ({}) of the given "
"linear transformation `init` cannot be "
"greater than its input dimensionality ({}).".format(
init.shape[0], init.shape[1]
)
)

if self.n_components is not None:
# Assert that self.n_components = init.shape[0]
if self.n_components != init.shape[0]:
raise ValueError(
"The preferred dimensionality of the "
"projected space `n_components` ({}) does"
" not match the output dimensionality of "
"the given linear transformation "
"`init` ({})!".format(self.n_components, init.shape[0])
)
elif init in ["auto", "pca", "lda", "identity", "random"]:
pass
else:
raise ValueError(
"`init` must be 'auto', 'pca', 'lda', 'identity', 'random' "
"or a numpy array of shape (n_components, n_features)."
)

return X, y, init

def _initialize(self, X, y, init):
"""Initialize the transformation.
Expand Down
28 changes: 5 additions & 23 deletions sklearn/neighbors/tests/test_nca.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sklearn.datasets import load_iris, make_classification, make_blobs
from sklearn.neighbors import NeighborhoodComponentsAnalysis
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import LabelEncoder


rng = check_random_state(0)
Expand Down Expand Up @@ -69,7 +70,8 @@ def __init__(self, X, y):
# Initialize a fake NCA and variables needed to compute the loss:
self.fake_nca = NeighborhoodComponentsAnalysis()
self.fake_nca.n_iter_ = np.inf
self.X, y, _ = self.fake_nca._validate_params(X, y)
self.X, y = self.fake_nca._validate_data(X, y, ensure_min_samples=2)
y = LabelEncoder().fit_transform(y)
self.same_class_mask = y[:, np.newaxis] == y[np.newaxis, :]

def callback(self, transformation, n_iter):
Expand Down Expand Up @@ -119,27 +121,6 @@ def test_params_validation():
NCA = NeighborhoodComponentsAnalysis
rng = np.random.RandomState(42)

# TypeError
with pytest.raises(TypeError):
NCA(max_iter="21").fit(X, y)
with pytest.raises(TypeError):
NCA(verbose="true").fit(X, y)
with pytest.raises(TypeError):
NCA(tol="1").fit(X, y)
with pytest.raises(TypeError):
NCA(n_components="invalid").fit(X, y)
with pytest.raises(TypeError):
NCA(warm_start=1).fit(X, y)

# ValueError
msg = (
r"`init` must be 'auto', 'pca', 'lda', 'identity', 'random' or a "
r"numpy array of shape (n_components, n_features)."
)
with pytest.raises(ValueError, match=re.escape(msg)):
NCA(init=1).fit(X, y)
with pytest.raises(ValueError, match="max_iter == -1, must be >= 1."):
NCA(max_iter=-1).fit(X, y)
init = rng.rand(5, 3)
msg = (
f"The output dimensionality ({init.shape[0]}) "
Expand Down Expand Up @@ -513,7 +494,8 @@ def __init__(self, X, y):
# function:
self.fake_nca = NeighborhoodComponentsAnalysis()
self.fake_nca.n_iter_ = np.inf
self.X, y, _ = self.fake_nca._validate_params(X, y)
self.X, y = self.fake_nca._validate_data(X, y, ensure_min_samples=2)
y = LabelEncoder().fit_transform(y)
self.same_class_mask = y[:, np.newaxis] == y[np.newaxis, :]

def callback(self, transformation, n_iter):
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"MiniBatchDictionaryLearning",
"MultiTaskElasticNet",
"MultiTaskLasso",
"NeighborhoodComponentsAnalysis",
"Nystroem",
"OAS",
"OPTICS",
Expand Down

0 comments on commit d7c978b

Please sign in to comment.