MAINT Parameter validation for descendants of BaseLibSVM (scikit-learn#24001)

Co-authored-by: Stefanie Molin <[email protected]>
Co-authored-by: jeremiedbb <[email protected]>
3 people committed Jul 27, 2022
1 parent 4f315db commit 329adf1
Showing 5 changed files with 80 additions and 113 deletions.
10 changes: 8 additions & 2 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -1,6 +1,7 @@
"""
Testing for the gradient boosting module (sklearn.ensemble.gradient_boosting).
"""
import re
import warnings
import numpy as np
from numpy.testing import assert_allclose
@@ -1233,9 +1234,14 @@ def test_gradient_boosting_with_init_pipeline():
     # sure we make the distinction between ValueError raised by a pipeline that
     # was passed sample_weight, and a ValueError raised by a regular estimator
     # whose input checking failed.
-    with pytest.raises(ValueError, match="nu <= 0 or nu > 1"):
+    invalid_nu = 1.5
+    err_msg = (
+        "The 'nu' parameter of NuSVR must be a float in the"
+        f" range (0.0, 1.0]. Got {invalid_nu} instead."
+    )
+    with pytest.raises(ValueError, match=re.escape(err_msg)):
         # Note that NuSVR properly supports sample_weight
-        init = NuSVR(gamma="auto", nu=1.5)
+        init = NuSVR(gamma="auto", nu=invalid_nu)
         gb = GradientBoostingRegressor(init=init)
         gb.fit(X, y, sample_weight=np.ones(X.shape[0]))
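Editor's note: the new assertion relies on the uniform message template produced by scikit-learn's parameter-validation framework. A hedged sketch of the behavior the test pins down, assuming a build that includes this change:

import numpy as np
from sklearn.svm import NuSVR

X = np.array([[0.0], [1.0], [2.0]])
y = np.array([0.0, 1.0, 2.0])
try:
    NuSVR(gamma="auto", nu=1.5).fit(X, y)
except ValueError as exc:
    # Expected along the lines of: "The 'nu' parameter of NuSVR must be
    # a float in the range (0.0, 1.0]. Got 1.5 instead."
    print(exc)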

63 changes: 36 additions & 27 deletions sklearn/svm/_base.py
@@ -1,6 +1,6 @@
 import warnings
-import numbers
 from abc import ABCMeta, abstractmethod
+from numbers import Integral, Real
 
 import numpy as np
 import scipy.sparse as sp
@@ -22,6 +22,7 @@
 from ..utils.validation import _num_samples
 from ..utils.validation import _check_sample_weight, check_consistent_length
 from ..utils.multiclass import check_classification_targets
+from ..utils._param_validation import Interval, StrOptions
 from ..exceptions import ConvergenceWarning
 from ..exceptions import NotFittedError

@@ -69,6 +70,30 @@ class BaseLibSVM(BaseEstimator, metaclass=ABCMeta):
     Parameter documentation is in the derived `SVC` class.
     """
 
+    _parameter_constraints = {
+        "kernel": [
+            StrOptions({"linear", "poly", "rbf", "sigmoid", "precomputed"}),
+            callable,
+        ],
+        "degree": [Interval(Integral, 0, None, closed="left")],
+        "gamma": [
+            StrOptions({"scale", "auto"}),
+            Interval(Real, 0.0, None, closed="neither"),
+        ],
+        "coef0": [Interval(Real, None, None, closed="neither")],
+        "tol": [Interval(Real, 0.0, None, closed="neither")],
+        "C": [Interval(Real, 0.0, None, closed="neither")],
+        "nu": [Interval(Real, 0.0, 1.0, closed="right")],
+        "epsilon": [Interval(Real, 0.0, None, closed="left")],
+        "shrinking": ["boolean"],
+        "probability": ["boolean"],
+        "cache_size": [Interval(Real, 0, None, closed="neither")],
+        "class_weight": [StrOptions({"balanced"}), dict, None],
+        "verbose": ["verbose"],
+        "max_iter": [Interval(Integral, -1, None, closed="left")],
+        "random_state": ["random_state"],
+    }
+
     # The order of these must match the integer values in LibSVM.
     # XXX These are actually the same in the dense case. Need to factor
     # this out.
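A note on reading these declarations: each value is a list of acceptable constraints, either a framework token ("boolean", "verbose", "random_state") or an object such as StrOptions or Interval. A self-contained sketch of the Interval `closed` semantics used above (it mimics, rather than imports, the private validation helpers):

def in_interval(value, low, high, closed):
    # closed="left" includes low, closed="right" includes high,
    # closed="neither" excludes both; None means unbounded on that side.
    low_ok = True if low is None else (
        value >= low if closed in ("left", "both") else value > low
    )
    high_ok = True if high is None else (
        value <= high if closed in ("right", "both") else value < high
    )
    return low_ok and high_ok

# "nu": Interval(Real, 0.0, 1.0, closed="right") accepts 1.0 but rejects 0.0.
assert in_interval(1.0, 0.0, 1.0, "right") and not in_interval(0.0, 0.0, 1.0, "right")
# "C": Interval(Real, 0.0, None, closed="neither") means strictly positive.
assert not in_interval(0.0, 0.0, None, "neither") and in_interval(1e-9, 0.0, None, "neither")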
@@ -152,6 +177,7 @@ def fit(self, X, y, sample_weight=None):
         If X is a dense array, then the other methods will not support sparse
         matrices as input.
         """
+        self._validate_params()
 
         rnd = check_random_state(self.random_state)

@@ -160,13 +186,6 @@ def fit(self, X, y, sample_weight=None):
             raise TypeError("Sparse precomputed kernels are not supported.")
         self._sparse = sparse and not callable(self.kernel)
 
-        if hasattr(self, "decision_function_shape"):
-            if self.decision_function_shape not in ("ovr", "ovo"):
-                raise ValueError(
-                    "decision_function_shape must be either 'ovr' or 'ovo', "
-                    f"got {self.decision_function_shape}."
-                )
-
         if callable(self.kernel):
             check_consistent_length(X, y)
         else:
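With the hand-rolled check gone, rejection happens in `_validate_params()` at the top of `fit`, while `__init__` keeps storing parameters unvalidated (see the test_common.py change below). A hedged sketch, assuming a build with this change:

from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
clf = SVC(decision_function_shape="bad")  # no error here: __init__ only stores params
try:
    clf.fit(X, y)  # validation now happens at fit time
except ValueError as exc:
    # Expected along the lines of: "The 'decision_function_shape' parameter
    # of SVC must be a str among {'ovr', 'ovo'}. Got 'bad' instead."
    print(exc)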
@@ -222,26 +241,8 @@ def fit(self, X, y, sample_weight=None):
                 self._gamma = 1.0 / (X.shape[1] * X_var) if X_var != 0 else 1.0
             elif self.gamma == "auto":
                 self._gamma = 1.0 / X.shape[1]
-            else:
-                raise ValueError(
-                    "When 'gamma' is a string, it should be either 'scale' or "
-                    f"'auto'. Got '{self.gamma!r}' instead."
-                )
-        elif isinstance(self.gamma, numbers.Real):
-            if self.gamma <= 0:
-                msg = (
-                    f"gamma value must be > 0; {self.gamma!r} is invalid. Use"
-                    " a positive number or use 'auto' to set gamma to a"
-                    " value of 1 / n_features."
-                )
-                raise ValueError(msg)
+        elif isinstance(self.gamma, Real):
             self._gamma = self.gamma
-        else:
-            msg = (
-                "The gamma value should be set to 'scale', 'auto' or a"
-                f" positive float value. {self.gamma!r} is not a valid option"
-            )
-            raise ValueError(msg)
 
         fit = self._sparse_fit if self._sparse else self._dense_fit
         if self.verbose:
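For reference, the retained branches compute the effective gamma exactly as above; a self-contained numeric sketch of the dense case:

import numpy as np

X = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
gamma_scale = 1.0 / (X.shape[1] * X.var())  # gamma="scale": 1 / (n_features * X.var())
gamma_auto = 1.0 / X.shape[1]               # gamma="auto":  1 / n_features
print(gamma_scale, gamma_auto)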
@@ -691,6 +692,14 @@ def n_support_(self):
 class BaseSVC(ClassifierMixin, BaseLibSVM, metaclass=ABCMeta):
     """ABC for LibSVM-based classifiers."""
 
+    _parameter_constraints = {
+        **BaseLibSVM._parameter_constraints,  # type: ignore
+        "decision_function_shape": [StrOptions({"ovr", "ovo"})],
+        "break_ties": ["boolean"],
+    }
+    for unused_param in ["epsilon", "nu"]:
+        _parameter_constraints.pop(unused_param)
+
     @abstractmethod
     def __init__(
         self,
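The pattern above (copy the parent mapping, extend it, then pop parameters the subclass does not expose) recurs in every concrete SVM class in this commit. A self-contained sketch with placeholder dicts (illustrative values, not scikit-learn API):

# The parent declares constraints for every parameter it knows about.
parent_constraints = {"C": "positive", "nu": "in (0, 1]", "kernel": "str or callable"}

# A child copies and extends, then prunes what it does not accept.
child_constraints = {
    **parent_constraints,
    "decision_function_shape": "one of {'ovr', 'ovo'}",
}
for unused_param in ["nu"]:
    child_constraints.pop(unused_param)

print(sorted(child_constraints))  # ['C', 'decision_function_shape', 'kernel']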
45 changes: 34 additions & 11 deletions sklearn/svm/_classes.py
@@ -587,14 +587,15 @@ class SVC(BaseSVC):
     degree : int, default=3
         Degree of the polynomial kernel function ('poly').
-        Ignored by all other kernels.
+        Must be non-negative. Ignored by all other kernels.
 
     gamma : {'scale', 'auto'} or float, default='scale'
         Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
 
         - if ``gamma='scale'`` (default) is passed then it uses
           1 / (n_features * X.var()) as value of gamma,
-        - if 'auto', uses 1 / n_features.
+        - if 'auto', uses 1 / n_features
+        - if float, must be non-negative.
 
         .. versionchanged:: 0.22
            The default value of ``gamma`` changed from 'auto' to 'scale'.
@@ -850,14 +851,15 @@ class NuSVC(BaseSVC):
     degree : int, default=3
         Degree of the polynomial kernel function ('poly').
-        Ignored by all other kernels.
+        Must be non-negative. Ignored by all other kernels.
 
     gamma : {'scale', 'auto'} or float, default='scale'
         Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
 
         - if ``gamma='scale'`` (default) is passed then it uses
           1 / (n_features * X.var()) as value of gamma,
-        - if 'auto', uses 1 / n_features.
+        - if 'auto', uses 1 / n_features
+        - if float, must be non-negative.
 
         .. versionchanged:: 0.22
            The default value of ``gamma`` changed from 'auto' to 'scale'.
@@ -1037,6 +1039,12 @@ class NuSVC(BaseSVC):

_impl = "nu_svc"

_parameter_constraints = {
**BaseSVC._parameter_constraints, # type: ignore
"nu": [Interval(Real, 0.0, 1.0, closed="right")],
}
_parameter_constraints.pop("C")

def __init__(
self,
*,
@@ -1114,14 +1122,15 @@ class SVR(RegressorMixin, BaseLibSVM):
     degree : int, default=3
         Degree of the polynomial kernel function ('poly').
-        Ignored by all other kernels.
+        Must be non-negative. Ignored by all other kernels.
 
     gamma : {'scale', 'auto'} or float, default='scale'
         Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
 
         - if ``gamma='scale'`` (default) is passed then it uses
           1 / (n_features * X.var()) as value of gamma,
-        - if 'auto', uses 1 / n_features.
+        - if 'auto', uses 1 / n_features
+        - if float, must be non-negative.
 
         .. versionchanged:: 0.22
            The default value of ``gamma`` changed from 'auto' to 'scale'.
@@ -1142,7 +1151,7 @@ class SVR(RegressorMixin, BaseLibSVM):
         Epsilon in the epsilon-SVR model. It specifies the epsilon-tube
         within which no penalty is associated in the training loss function
         with points predicted within a distance epsilon from the actual
-        value.
+        value. Must be non-negative.
 
     shrinking : bool, default=True
         Whether to use the shrinking heuristic.
@@ -1247,6 +1256,10 @@ class SVR(RegressorMixin, BaseLibSVM):

_impl = "epsilon_svr"

_parameter_constraints = {**BaseLibSVM._parameter_constraints} # type: ignore
for unused_param in ["class_weight", "nu", "probability", "random_state"]:
_parameter_constraints.pop(unused_param)

def __init__(
self,
*,
@@ -1329,14 +1342,15 @@ class NuSVR(RegressorMixin, BaseLibSVM):
     degree : int, default=3
         Degree of the polynomial kernel function ('poly').
-        Ignored by all other kernels.
+        Must be non-negative. Ignored by all other kernels.
 
     gamma : {'scale', 'auto'} or float, default='scale'
         Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
 
         - if ``gamma='scale'`` (default) is passed then it uses
           1 / (n_features * X.var()) as value of gamma,
-        - if 'auto', uses 1 / n_features.
+        - if 'auto', uses 1 / n_features
+        - if float, must be non-negative.
 
         .. versionchanged:: 0.22
            The default value of ``gamma`` changed from 'auto' to 'scale'.
@@ -1451,6 +1465,10 @@ class NuSVR(RegressorMixin, BaseLibSVM):

_impl = "nu_svr"

_parameter_constraints = {**BaseLibSVM._parameter_constraints} # type: ignore
for unused_param in ["class_weight", "epsilon", "probability", "random_state"]:
_parameter_constraints.pop(unused_param)

def __init__(
self,
*,
@@ -1523,14 +1541,15 @@ class OneClassSVM(OutlierMixin, BaseLibSVM):
     degree : int, default=3
         Degree of the polynomial kernel function ('poly').
-        Ignored by all other kernels.
+        Must be non-negative. Ignored by all other kernels.
 
     gamma : {'scale', 'auto'} or float, default='scale'
         Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
 
         - if ``gamma='scale'`` (default) is passed then it uses
           1 / (n_features * X.var()) as value of gamma,
-        - if 'auto', uses 1 / n_features.
+        - if 'auto', uses 1 / n_features
+        - if float, must be non-negative.
 
         .. versionchanged:: 0.22
            The default value of ``gamma`` changed from 'auto' to 'scale'.
@@ -1645,6 +1664,10 @@ class OneClassSVM(OutlierMixin, BaseLibSVM):

_impl = "one_class"

_parameter_constraints = {**BaseLibSVM._parameter_constraints} # type: ignore
for unused_param in ["C", "class_weight", "epsilon", "probability", "random_state"]:
_parameter_constraints.pop(unused_param)

def __init__(
self,
*,
70 changes: 2 additions & 68 deletions sklearn/svm/tests/test_svm.py
@@ -455,9 +455,6 @@ def test_decision_function_shape(SVM):
         dec = clf.decision_function(X_train)
         assert dec.shape == (len(X_train), 10)
 
-    with pytest.raises(ValueError, match="must be either 'ovr' or 'ovo'"):
-        SVM(decision_function_shape="bad").fit(X_train, y_train)
-
 
 def test_svr_predict():
     # Test SVR's decision_function
@@ -685,19 +682,10 @@ def test_auto_weight():


 def test_bad_input():
-    # Test that it gives proper exception on deficient input
-    # impossible value of C
-    with pytest.raises(ValueError):
-        svm.SVC(C=-1).fit(X, Y)
-
-    # impossible value of nu
-    clf = svm.NuSVC(nu=0.0)
-    with pytest.raises(ValueError):
-        clf.fit(X, Y)
-
     # Test dimensions for labels
     Y2 = Y[:-1]  # wrong dimensions for labels
     with pytest.raises(ValueError):
-        clf.fit(X, Y2)
+        svm.SVC().fit(X, Y2)
 
     # Test with arrays that are non-contiguous.
     for clf in (svm.SVC(), svm.LinearSVC(random_state=0)):
@@ -745,60 +733,6 @@ def test_svc_nonfinite_params():
         clf.fit(X, y)
 
 
-@pytest.mark.parametrize(
-    "Estimator, data",
-    [
-        (svm.SVC, datasets.load_iris(return_X_y=True)),
-        (svm.NuSVC, datasets.load_iris(return_X_y=True)),
-        (svm.SVR, datasets.load_diabetes(return_X_y=True)),
-        (svm.NuSVR, datasets.load_diabetes(return_X_y=True)),
-        (svm.OneClassSVM, datasets.load_iris(return_X_y=True)),
-    ],
-)
-@pytest.mark.parametrize(
-    "gamma, err_msg",
-    [
-        (
-            "auto_deprecated",
-            "When 'gamma' is a string, it should be either 'scale' or 'auto'",
-        ),
-        (
-            -1,
-            "gamma value must be > 0; -1 is invalid. Use"
-            " a positive number or use 'auto' to set gamma to a"
-            " value of 1 / n_features.",
-        ),
-        (
-            0.0,
-            "gamma value must be > 0; 0.0 is invalid. Use"
-            " a positive number or use 'auto' to set gamma to a"
-            " value of 1 / n_features.",
-        ),
-        (
-            np.array([1.0, 4.0]),
-            "The gamma value should be set to 'scale',"
-            f" 'auto' or a positive float value. {np.array([1.0, 4.0])!r}"
-            " is not a valid option",
-        ),
-        (
-            [],
-            "The gamma value should be set to 'scale', 'auto' or a positive"
-            f" float value. {[]} is not a valid option",
-        ),
-        (
-            {},
-            "The gamma value should be set to 'scale', 'auto' or a positive"
-            " float value. {} is not a valid option",
-        ),
-    ],
-)
-def test_svm_gamma_error(Estimator, data, gamma, err_msg):
-    X, y = data
-    est = Estimator(gamma=gamma)
-    with pytest.raises(ValueError, match=(re.escape(err_msg))):
-        est.fit(X, y)
-
-
 def test_unicode_kernel():
     # Test that a unicode kernel name does not cause a TypeError
     clf = svm.SVC(kernel="linear", probability=True)
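The deleted per-estimator gamma tests are superseded by the declarative `gamma` constraint on BaseLibSVM. A hedged sketch of the replacement behavior, assuming a build with this change (the exact wording is owned by the validation framework):

import numpy as np
from sklearn import svm

X = np.array([[0.0, 1.0], [1.0, 0.0]])
y = [0, 1]
for bad_gamma in ("auto_deprecated", -1, 0.0, [], {}):
    try:
        svm.SVC(gamma=bad_gamma).fit(X, y)
    except ValueError as exc:
        # Expected along the lines of: "The 'gamma' parameter of SVC must be
        # a str among {'scale', 'auto'} or a float in the range (0.0, inf). Got ... instead."
        print(exc)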
5 changes: 0 additions & 5 deletions sklearn/tests/test_common.py
@@ -473,12 +473,9 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"MultiTaskElasticNet",
"MultiTaskLasso",
"NeighborhoodComponentsAnalysis",
"NuSVC",
"NuSVR",
"Nystroem",
"OAS",
"OPTICS",
"OneClassSVM",
"OneVsOneClassifier",
"OneVsRestClassifier",
"PatchExtractor",
@@ -492,8 +489,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"RegressorChain",
"RidgeCV",
"RidgeClassifierCV",
"SVC",
"SVR",
"SelectFdr",
"SelectFpr",
"SelectFromModel",
