Skip to content

Commit

Permalink
[MRG + 1] Move n_iter and get_params invariance tests to common estim…
Browse files Browse the repository at this point in the history
…ator_checks (scikit-learn#7677)

* Test get_params invariance in common estimator tests

Remove test_get_params_invariance() from `test_common.py` and add
test call to _yield_all_tests() in `estimator_checks.py` to make
sure that get_params(deep=False) of a given Estimator returns a
subset of get_params(deep=True).

Compared to test_get_params_invariance(), it is NOT tested anymore
whether the given Estimator has an attribute get_params since
class BaseEstimator in `base.py` defines such an attribute
for each Estimator.

Partially addresses issue scikit-learn#7533
Also related to issue scikit-learn#4465

* Move test_transformer_n_iter() to estimator_checks.py

Remove the test test_transformer_n_iter() from tests/test_common.py
and perform the test logic in utils/estimator_checks.py instead.
Specifically, the method _yield_transformer_checks() now yields
check_transformer_n_iter() as part of the set of tests for
transformers.

test_transformer_n_iter() tests that that transformers with an
attribute max_iter, return the attribute of n_iter at least 1.

Partially addresses latter part of issue scikit-learn#7533

* Move test_non_transformer_estimators_n_iter() to estimator_checks.py

Remove the test_non_transformer_estimators_n_iter() from
tests/test_common.py; perform the test logic in
utils/estimator_checks.py instead.
Specifically, the method _yield_non_meta_checks() now yields
check_non_transformer_estimators_n_iter().

test_transformer_n_iter() tests that that estimators that are not
transformers with an attribute max_iter, return the attribute n_iter
of at least 1.

NOTE: The current implementation makes said test run for more
estimators than before this commit.
For some of these estimators, the test fails. This needs to be addressed
(see FIXME in line 111-115 of utils/estimator_checks.py for a potential
place to start).

Partially addresses latter part of issue scikit-learn#7533

* Fix check_non_transformer_estimators_n_iter calls

test_transformer_n_iter() test is now only run for
estimators where the test is applicable.

Partially addresses latter part of issue scikit-learn#7533

* Run check_non_transformer_estimators_n_iter on multi-class estimators

To do this, use helper method multioutput_estimator_convert_y_2d.
Also remove multi_output parameter from
check_non_transformer_estimators_n_iter since this parameter is not
used anywhere and corresponding cases should be handled by said
helper method.

Also, some pep8 line length fixes.

* Fix documentation for n_iter tests

There was some confusion between attributes and parameters.
Also rename n_iter to n_iter_
  • Loading branch information
JungeAlexander authored and amueller committed Oct 20, 2016
1 parent 61a9a9b commit 568c002
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 108 deletions.
75 changes: 1 addition & 74 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@
from sklearn.linear_model.base import LinearClassifierMixin
from sklearn.utils.estimator_checks import (
_yield_all_checks,
CROSS_DECOMPOSITION,
check_parameters_default_constructible,
check_class_weight_balanced_linear_classifier,
check_transformer_n_iter,
check_non_transformer_estimators_n_iter,
check_get_params_invariance)
check_class_weight_balanced_linear_classifier)


def test_all_estimator_no_base_class():
Expand Down Expand Up @@ -162,72 +158,3 @@ def test_all_tests_are_importable():
'{0} do not have `tests` subpackages. Perhaps they require '
'__init__.py or an add_subpackage directive in the parent '
'setup.py'.format(missing_tests))


def test_non_transformer_estimators_n_iter():
# Test that all estimators of type which are non-transformer
# and which have an attribute of max_iter, return the attribute
# of n_iter atleast 1.
for est_type in ['regressor', 'classifier', 'cluster']:
regressors = all_estimators(type_filter=est_type)
for name, Estimator in regressors:
# LassoLars stops early for the default alpha=1.0 for
# the iris dataset.
if name == 'LassoLars':
estimator = Estimator(alpha=0.)
else:
with ignore_warnings(category=DeprecationWarning):
estimator = Estimator()
if hasattr(estimator, "max_iter"):
# These models are dependent on external solvers like
# libsvm and accessing the iter parameter is non-trivial.
if name in (['Ridge', 'SVR', 'NuSVR', 'NuSVC',
'RidgeClassifier', 'SVC', 'RandomizedLasso',
'LogisticRegressionCV']):
continue

# Tested in test_transformer_n_iter below
elif (name in CROSS_DECOMPOSITION or
name in ['LinearSVC', 'LogisticRegression']):
continue

else:
# Multitask models related to ENet cannot handle
# if y is mono-output.
yield (_named_check(
check_non_transformer_estimators_n_iter, name),
name, estimator, 'Multi' in name)


def test_transformer_n_iter():
transformers = all_estimators(type_filter='transformer')
for name, Estimator in transformers:
with ignore_warnings(category=DeprecationWarning):
estimator = Estimator()
# Dependent on external solvers and hence accessing the iter
# param is non-trivial.
external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
'RandomizedLasso', 'LogisticRegressionCV']

if hasattr(estimator, "max_iter") and name not in external_solver:
yield _named_check(
check_transformer_n_iter, name), name, estimator


def test_get_params_invariance():
# Test for estimators that support get_params, that
# get_params(deep=False) is a subset of get_params(deep=True)
# Related to issue #4465

estimators = all_estimators(include_meta_estimators=False,
include_other=True)
for name, Estimator in estimators:
if hasattr(Estimator, 'get_params'):
# If class is deprecated, ignore deprecated warnings
if hasattr(Estimator.__init__, "deprecated_original"):
with ignore_warnings():
yield _named_check(
check_get_params_invariance, name), name, Estimator
else:
yield _named_check(
check_get_params_invariance, name), name, Estimator
102 changes: 68 additions & 34 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def _yield_classifier_checks(name, Classifier):
if 'class_weight' in Classifier().get_params().keys():
yield check_class_weight_classifiers

yield check_non_transformer_estimators_n_iter


@ignore_warnings(category=DeprecationWarning)
def check_supervised_y_no_nan(name, Estimator):
Expand Down Expand Up @@ -172,6 +174,7 @@ def _yield_regressor_checks(name, Regressor):
if name != "GaussianProcessRegressor":
# Test if NotFittedError is raised
yield check_estimators_unfitted
yield check_non_transformer_estimators_n_iter


def _yield_transformer_checks(name, Transformer):
Expand All @@ -186,6 +189,13 @@ def _yield_transformer_checks(name, Transformer):
# basic tests
yield check_transformer_general
yield check_transformers_unfitted
# Dependent on external solvers and hence accessing the iter
# param is non-trivial.
external_solver = ['Isomap', 'KernelPCA', 'LocallyLinearEmbedding',
'RandomizedLasso', 'LogisticRegressionCV']
if name not in external_solver:
yield check_transformer_n_iter



def _yield_clustering_checks(name, Clusterer):
Expand All @@ -195,6 +205,7 @@ def _yield_clustering_checks(name, Clusterer):
# let's not test that here.
yield check_clustering
yield check_estimators_partial_fit_n_features
yield check_non_transformer_estimators_n_iter


def _yield_all_checks(name, Estimator):
Expand All @@ -218,6 +229,7 @@ def _yield_all_checks(name, Estimator):
yield check_fit2d_1feature
yield check_fit1d_1feature
yield check_fit1d_1sample
yield check_get_params_invariance


def check_estimator(Estimator):
Expand Down Expand Up @@ -1477,51 +1489,73 @@ def multioutput_estimator_convert_y_2d(name, y):


@ignore_warnings(category=DeprecationWarning)
def check_non_transformer_estimators_n_iter(name, estimator,
multi_output=False):
# Check if all iterative solvers, run for more than one iteration

iris = load_iris()
X, y_ = iris.data, iris.target

if multi_output:
y_ = np.reshape(y_, (-1, 1))
def check_non_transformer_estimators_n_iter(name, Estimator):
# Test that estimators that are not transformers with a parameter
# max_iter, return the attribute of n_iter_ at least 1.

# These models are dependent on external solvers like
# libsvm and accessing the iter parameter is non-trivial.
not_run_check_n_iter = ['Ridge', 'SVR', 'NuSVR', 'NuSVC',
'RidgeClassifier', 'SVC', 'RandomizedLasso',
'LogisticRegressionCV', 'LinearSVC',
'LogisticRegression']

# Tested in test_transformer_n_iter
not_run_check_n_iter += CROSS_DECOMPOSITION
if name in not_run_check_n_iter:
return

set_random_state(estimator, 0)
if name == 'AffinityPropagation':
estimator.fit(X)
# LassoLars stops early for the default alpha=1.0 the iris dataset.
if name == 'LassoLars':
estimator = Estimator(alpha=0.)
else:
estimator.fit(X, y_)
estimator = Estimator()
if hasattr(estimator, 'max_iter'):
iris = load_iris()
X, y_ = iris.data, iris.target
y_ = multioutput_estimator_convert_y_2d(name, y_)

set_random_state(estimator, 0)
if name == 'AffinityPropagation':
estimator.fit(X)
else:
estimator.fit(X, y_)

# HuberRegressor depends on scipy.optimize.fmin_l_bfgs_b
# which doesn't return a n_iter for old versions of SciPy.
if not (name == 'HuberRegressor' and estimator.n_iter_ is None):
assert_greater_equal(estimator.n_iter_, 1)
# HuberRegressor depends on scipy.optimize.fmin_l_bfgs_b
# which doesn't return a n_iter for old versions of SciPy.
if not (name == 'HuberRegressor' and estimator.n_iter_ is None):
assert_greater_equal(estimator.n_iter_, 1)


@ignore_warnings(category=DeprecationWarning)
def check_transformer_n_iter(name, estimator):
if name in CROSS_DECOMPOSITION:
# Check using default data
X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
y_ = [[0.1, -0.2], [0.9, 1.1], [0.1, -0.5], [0.3, -0.2]]
def check_transformer_n_iter(name, Estimator):
# Test that transformers with a parameter max_iter, return the
# attribute of n_iter_ at least 1.
estimator = Estimator()
if hasattr(estimator, "max_iter"):
if name in CROSS_DECOMPOSITION:
# Check using default data
X = [[0., 0., 1.], [1., 0., 0.], [2., 2., 2.], [2., 5., 4.]]
y_ = [[0.1, -0.2], [0.9, 1.1], [0.1, -0.5], [0.3, -0.2]]

else:
X, y_ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
random_state=0, n_features=2, cluster_std=0.1)
X -= X.min() - 0.1
set_random_state(estimator, 0)
estimator.fit(X, y_)
else:
X, y_ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
random_state=0, n_features=2, cluster_std=0.1)
X -= X.min() - 0.1
set_random_state(estimator, 0)
estimator.fit(X, y_)

# These return a n_iter per component.
if name in CROSS_DECOMPOSITION:
for iter_ in estimator.n_iter_:
assert_greater_equal(iter_, 1)
else:
assert_greater_equal(estimator.n_iter_, 1)
# These return a n_iter per component.
if name in CROSS_DECOMPOSITION:
for iter_ in estimator.n_iter_:
assert_greater_equal(iter_, 1)
else:
assert_greater_equal(estimator.n_iter_, 1)


@ignore_warnings(category=DeprecationWarning)
def check_get_params_invariance(name, estimator):
# Checks if get_params(deep=False) is a subset of get_params(deep=True)
class T(BaseEstimator):
"""Mock classifier
"""
Expand Down

0 comments on commit 568c002

Please sign in to comment.