Skip to content

Commit

Permalink
MAINT validate parameter in KernelPCA (scikit-learn#24020)
Browse files Browse the repository at this point in the history
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: jeremiedbb <[email protected]>
  • Loading branch information
3 people committed Jul 28, 2022
1 parent 6e99407 commit 3312bc2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 36 deletions.
47 changes: 40 additions & 7 deletions sklearn/decomposition/_kernel_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# License: BSD 3 clause

import numpy as np
import numbers
from numbers import Integral, Real
from scipy import linalg
from scipy.sparse.linalg import eigsh

Expand All @@ -14,8 +14,8 @@
from ..utils.validation import (
check_is_fitted,
_check_psd_eigenvalues,
check_scalar,
)
from ..utils._param_validation import Interval, StrOptions
from ..utils.deprecation import deprecated
from ..exceptions import NotFittedError
from ..base import BaseEstimator, TransformerMixin, _ClassNamePrefixFeaturesOutMixin
Expand All @@ -42,8 +42,8 @@ class KernelPCA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
n_components : int, default=None
Number of components. If None, all non-zero components are kept.
kernel : {'linear', 'poly', \
'rbf', 'sigmoid', 'cosine', 'precomputed'}, default='linear'
kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'cosine', 'precomputed'} \
or callable, default='linear'
Kernel used for PCA.
gamma : float, default=None
Expand Down Expand Up @@ -239,6 +239,40 @@ class KernelPCA(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimato
(1797, 7)
"""

_parameter_constraints = {
"n_components": [
Interval(Integral, 1, None, closed="left"),
None,
],
"kernel": [
StrOptions({"linear", "poly", "rbf", "sigmoid", "cosine", "precomputed"}),
callable,
],
"gamma": [
Interval(Real, 0, None, closed="left"),
None,
],
"degree": [Interval(Integral, 0, None, closed="left")],
"coef0": [Interval(Real, None, None, closed="neither")],
"kernel_params": [dict, None],
"alpha": [Interval(Real, 0, None, closed="left")],
"fit_inverse_transform": ["boolean"],
"eigen_solver": [StrOptions({"auto", "dense", "arpack", "randomized"})],
"tol": [Interval(Real, 0, None, closed="left")],
"max_iter": [
Interval(Integral, 1, None, closed="left"),
None,
],
"iterated_power": [
Interval(Integral, 0, None, closed="left"),
StrOptions({"auto"}),
],
"remove_zero_eig": ["boolean"],
"random_state": ["random_state"],
"copy_X": ["boolean"],
"n_jobs": [None, Integral],
}

def __init__(
self,
n_components=None,
Expand Down Expand Up @@ -313,7 +347,6 @@ def _fit_transform(self, K):
if self.n_components is None:
n_components = K.shape[0] # use all dimensions
else:
check_scalar(self.n_components, "n_components", numbers.Integral, min_val=1)
n_components = min(K.shape[0], self.n_components)

# compute eigenvectors
Expand Down Expand Up @@ -343,8 +376,6 @@ def _fit_transform(self, K):
random_state=self.random_state,
selection="module",
)
else:
raise ValueError("Unsupported value for `eigen_solver`: %r" % eigen_solver)

# make sure that the eigenvalues are ok and fix numerical issues
self.eigenvalues_ = _check_psd_eigenvalues(
Expand Down Expand Up @@ -416,6 +447,8 @@ def fit(self, X, y=None):
self : object
Returns the instance itself.
"""
self._validate_params()

if self.fit_inverse_transform and self.kernel == "precomputed":
raise ValueError("Cannot fit_inverse_transform with a precomputed kernel.")
X = self._validate_data(X, accept_sparse="csr", copy=self.copy_X)
Expand Down
28 changes: 0 additions & 28 deletions sklearn/decomposition/tests/test_kernel_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ def histogram(x, y, **kwargs):
assert X_pred2.shape == X_pred.shape


def test_kernel_pca_invalid_solver():
"""Check that kPCA raises an error if the solver parameter is invalid"""
with pytest.raises(ValueError):
KernelPCA(eigen_solver="unknown").fit(np.random.randn(10, 10))


def test_kernel_pca_invalid_parameters():
"""Check that kPCA raises an error if the parameters are invalid
Expand Down Expand Up @@ -204,16 +198,6 @@ def test_kernel_pca_n_components():
assert shape == (2, c)


@pytest.mark.parametrize("n_components", [-1, 0])
def test_kernal_pca_too_few_components(n_components):
rng = np.random.RandomState(0)
X_fit = rng.random_sample((5, 4))
kpca = KernelPCA(n_components=n_components)
msg = "n_components.* must be >= 1"
with pytest.raises(ValueError, match=msg):
kpca.fit(X_fit)


def test_remove_zero_eig():
"""Check that the ``remove_zero_eig`` parameter works correctly.
Expand Down Expand Up @@ -326,18 +310,6 @@ def test_kernel_pca_precomputed_non_symmetric(solver):
assert_array_equal(kpca.eigenvalues_, kpca_c.eigenvalues_)


def test_kernel_pca_invalid_kernel():
"""Tests that using an invalid kernel name raises a ValueError
An invalid kernel name should raise a ValueError at fit time.
"""
rng = np.random.RandomState(0)
X_fit = rng.random_sample((2, 4))
kpca = KernelPCA(kernel="tototiti")
with pytest.raises(ValueError):
kpca.fit(X_fit)


def test_gridsearch_pipeline():
"""Check that kPCA works as expected in a grid search pipeline
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"HashingVectorizer",
"Isomap",
"IterativeImputer",
"KernelPCA",
"LabelPropagation",
"LabelSpreading",
"Lars",
Expand Down

0 comments on commit 3312bc2

Please sign in to comment.