Skip to content

Commit

Permalink
FEA SLEP006: Metadata routing for validation_curve (scikit-learn#29329
Browse files Browse the repository at this point in the history
)

Co-authored-by: Adam Li <[email protected]>
  • Loading branch information
StefanieSenger and adam2392 committed Jul 5, 2024
1 parent 64ab789 commit e1cf244
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 27 deletions.
2 changes: 1 addition & 1 deletion doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ Meta-estimators and functions supporting metadata routing:
- :func:`sklearn.model_selection.cross_val_score`
- :func:`sklearn.model_selection.cross_val_predict`
- :class:`sklearn.model_selection.learning_curve`
- :class:`sklearn.model_selection.validation_curve`
- :class:`sklearn.multiclass.OneVsOneClassifier`
- :class:`sklearn.multiclass.OneVsRestClassifier`
- :class:`sklearn.multiclass.OutputCodeClassifier`
Expand All @@ -323,5 +324,4 @@ Meta-estimators and tools not supporting metadata routing yet:
- :class:`sklearn.feature_selection.RFECV`
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
- :class:`sklearn.model_selection.permutation_test_score`
- :class:`sklearn.model_selection.validation_curve`
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`
4 changes: 4 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ more details.
params to the underlying regressor.
:pr:`29136` by :user:`Omar Salman <OmarManzoor>`.

- |Feature| :func:`model_selection.validation_curve` now supports metadata routing for
the `fit` method of its estimator and for its underlying CV splitter and scorer.
:pr:`29329` by :user:`Stefanie Senger <StefanieSenger>`.

Dropping official support for PyPy
----------------------------------

Expand Down
83 changes: 77 additions & 6 deletions sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,7 +1855,7 @@ def learning_curve(
Parameters to pass to the fit method of the estimator.
.. deprecated:: 1.6
This parameter is deprecated and will be removed in version 1.6. Use
This parameter is deprecated and will be removed in version 1.8. Use
``params`` instead.
params : dict, default=None
Expand Down Expand Up @@ -2221,6 +2221,7 @@ def _incremental_fit_estimator(
"verbose": ["verbose"],
"error_score": [StrOptions({"raise"}), Real],
"fit_params": [dict, None],
"params": [dict, None],
},
prefer_skip_nested_validation=False, # estimator is not validated yet
)
Expand All @@ -2239,6 +2240,7 @@ def validation_curve(
verbose=0,
error_score=np.nan,
fit_params=None,
params=None,
):
"""Validation curve.
Expand Down Expand Up @@ -2277,6 +2279,13 @@ def validation_curve(
train/test set. Only used in conjunction with a "Group" :term:`cv`
instance (e.g., :class:`GroupKFold`).
.. versionchanged:: 1.6
``groups`` can only be passed if metadata routing is not enabled
via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
is enabled, pass ``groups`` alongside other metadata via the ``params``
argument instead. E.g.:
``validation_curve(..., params={'groups': groups})``.
cv : int, cross-validation generator or an iterable, default=None
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
Expand Down Expand Up @@ -2327,7 +2336,22 @@ def validation_curve(
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
.. versionadded:: 0.24
.. deprecated:: 1.6
This parameter is deprecated and will be removed in version 1.8. Use
``params`` instead.
params : dict, default=None
Parameters to pass to the estimator, scorer and cross-validation object.
- If `enable_metadata_routing=False` (default):
Parameters directly passed to the `fit` method of the estimator.
- If `enable_metadata_routing=True`:
Parameters safely routed to the `fit` method of the estimator, to the
scorer and to the cross-validation object. See :ref:`Metadata Routing User
Guide <metadata_routing>` for more details.
.. versionadded:: 1.6
Returns
-------
Expand Down Expand Up @@ -2358,11 +2382,59 @@ def validation_curve(
>>> print(f"The average test accuracy is {test_scores.mean():.2f}")
The average test accuracy is 0.81
"""
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")
X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
scorer = check_scoring(estimator, scoring=scoring)

if _routing_enabled():
router = (
MetadataRouter(owner="validation_curve")
.add(
estimator=estimator,
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
)
.add(
splitter=cv,
method_mapping=MethodMapping().add(caller="fit", callee="split"),
)
.add(
scorer=scorer,
method_mapping=MethodMapping().add(caller="fit", callee="score"),
)
)

try:
routed_params = process_routing(router, "fit", **params)
except UnsetMetadataPassedError as e:
# The default exception would mention `fit` since in the above
# `process_routing` code, we pass `fit` as the caller. However,
# the user is not calling `fit` directly, so we change the message
# to make it more suitable for this case.
unrequested_params = sorted(e.unrequested_params)
raise UnsetMetadataPassedError(
message=(
f"{unrequested_params} are passed to `validation_curve` but are not"
" explicitly set as requested or not requested for"
f" validation_curve's estimator: {estimator.__class__.__name__}."
" Call `.set_fit_request({{metadata}}=True)` on the estimator for"
f" each metadata in {unrequested_params} that you"
" want to use and `metadata=False` for not using it. See the"
" Metadata Routing User guide"
" <https://scikit-learn.org/stable/metadata_routing.html> for more"
" information."
),
unrequested_params=e.unrequested_params,
routed_params=e.routed_params,
)

else:
routed_params = Bunch()
routed_params.estimator = Bunch(fit=params)
routed_params.splitter = Bunch(split={"groups": groups})
routed_params.scorer = Bunch(score={})

parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose)
results = parallel(
delayed(_fit_and_score)(
Expand All @@ -2374,14 +2446,13 @@ def validation_curve(
test=test,
verbose=verbose,
parameters={param_name: v},
fit_params=fit_params,
# TODO(SLEP6): support score params here
score_params=None,
fit_params=routed_params.estimator.fit,
score_params=routed_params.scorer.score,
return_train_score=True,
error_score=error_score,
)
# NOTE do not change order of iteration to allow one time cv splitters
for train, test in cv.split(X, y, groups)
for train, test in cv.split(X, y, **routed_params.splitter.split)
for v in param_range
)
n_params = len(param_range)
Expand Down
81 changes: 61 additions & 20 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1697,7 +1697,7 @@ def test_validation_curve_cv_splits_consistency():
assert_array_almost_equal(np.array(scores3), np.array(scores1))


def test_validation_curve_fit_params():
def test_validation_curve_params():
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
clf = CheckingClassifier(expected_sample_weight=True)
Expand All @@ -1722,7 +1722,7 @@ def test_validation_curve_fit_params():
param_name="foo_param",
param_range=[1, 2, 3],
error_score="raise",
fit_params={"sample_weight": np.ones(1)},
params={"sample_weight": np.ones(1)},
)
validation_curve(
clf,
Expand All @@ -1731,7 +1731,7 @@ def test_validation_curve_fit_params():
param_name="foo_param",
param_range=[1, 2, 3],
error_score="raise",
fit_params={"sample_weight": np.ones(10)},
params={"sample_weight": np.ones(10)},
)


Expand Down Expand Up @@ -2482,29 +2482,54 @@ def test_cross_validate_return_indices(global_random_seed):
assert_array_equal(test_indices[split_idx], expected_test_idx)


# Tests for metadata routing in cross_val* and learning_curve
# ===========================================================
# Tests for metadata routing in cross_val* and in *curve
# ======================================================


# TODO(1.6): remove `cross_validate` and `cross_val_predict` from this test in 1.6 and
# `learning_curve` in 1.8
@pytest.mark.parametrize("func", [cross_validate, cross_val_predict, learning_curve])
def test_fit_param_deprecation(func):
# `learning_curve` and `validation_curve` in 1.8
@pytest.mark.parametrize(
"func, extra_args",
[
(cross_validate, {}),
(cross_val_score, {}),
(cross_val_predict, {}),
(learning_curve, {}),
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
],
)
def test_fit_param_deprecation(func, extra_args):
"""Check that we warn about deprecating `fit_params`."""
with pytest.warns(FutureWarning, match="`fit_params` is deprecated"):
func(estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={})
func(
estimator=ConsumingClassifier(), X=X, y=y, cv=2, fit_params={}, **extra_args
)

with pytest.raises(
ValueError, match="`params` and `fit_params` cannot both be provided"
):
func(estimator=ConsumingClassifier(), X=X, y=y, fit_params={}, params={})
func(
estimator=ConsumingClassifier(),
X=X,
y=y,
fit_params={},
params={},
**extra_args,
)


@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize(
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
"func, extra_args",
[
(cross_validate, {}),
(cross_val_score, {}),
(cross_val_predict, {}),
(learning_curve, {}),
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
],
)
def test_groups_with_routing_validation(func):
def test_groups_with_routing_validation(func, extra_args):
"""Check that we raise an error if `groups` are passed to the cv method instead
of `params` when metadata routing is enabled.
"""
Expand All @@ -2514,14 +2539,22 @@ def test_groups_with_routing_validation(func):
X=X,
y=y,
groups=[],
**extra_args,
)


@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize(
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
"func, extra_args",
[
(cross_validate, {}),
(cross_val_score, {}),
(cross_val_predict, {}),
(learning_curve, {}),
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
],
)
def test_passed_unrequested_metadata(func):
def test_passed_unrequested_metadata(func, extra_args):
"""Check that we raise an error when passing metadata that is not
requested."""
err_msg = re.escape("but are not explicitly set as requested or not requested")
Expand All @@ -2531,14 +2564,22 @@ def test_passed_unrequested_metadata(func):
X=X,
y=y,
params=dict(metadata=[]),
**extra_args,
)


@pytest.mark.usefixtures("enable_slep006")
@pytest.mark.parametrize(
"func", [cross_validate, cross_val_score, cross_val_predict, learning_curve]
"func, extra_args",
[
(cross_validate, {}),
(cross_val_score, {}),
(cross_val_predict, {}),
(learning_curve, {}),
(validation_curve, {"param_name": "alpha", "param_range": np.array([1])}),
],
)
def test_validation_functions_routing(func):
def test_validation_functions_routing(func, extra_args):
"""Check that the respective cv method is properly dispatching the metadata
to the consumer."""
scorer_registry = _Registry()
Expand All @@ -2563,12 +2604,11 @@ def test_validation_functions_routing(func):
fit_sample_weight = rng.rand(n_samples)
fit_metadata = rng.rand(n_samples)

extra_params = {
scoring_args = {
cross_validate: dict(scoring=dict(my_scorer=scorer, accuracy="accuracy")),
# cross_val_score and learning_curve don't support multiple scorers:
cross_val_score: dict(scoring=scorer),
learning_curve: dict(scoring=scorer),
# cross_val_predict doesn't need a scorer
validation_curve: dict(scoring=scorer),
cross_val_predict: dict(),
}

Expand All @@ -2590,7 +2630,8 @@ def test_validation_functions_routing(func):
X=X,
y=y,
cv=splitter,
**extra_params[func],
**scoring_args[func],
**extra_args,
params=params,
)

Expand Down

0 comments on commit e1cf244

Please sign in to comment.