Skip to content

Commit

Permalink
MAINT Parameters validation for sklearn.model_selection.cross_val_pre…
Browse files Browse the repository at this point in the history
…dict (scikit-learn#26252)

Co-authored-by: jeremie du boisberranger <[email protected]>
  • Loading branch information
dmitrylala and jeremiedbb committed Jun 28, 2023
1 parent 845771f commit 7f871fe
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
31 changes: 28 additions & 3 deletions sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,30 @@ def _score(estimator, X_test, y_test, scorer, error_score="raise"):
return scores


@validate_params(
{
"estimator": [HasMethods(["fit", "predict"])],
"X": ["array-like", "sparse matrix"],
"y": ["array-like", None],
"groups": ["array-like", None],
"cv": ["cv_object"],
"n_jobs": [Integral, None],
"verbose": ["verbose"],
"fit_params": [dict, None],
"pre_dispatch": [Integral, str, None],
"method": [
StrOptions(
{
"predict",
"predict_proba",
"predict_log_proba",
"decision_function",
}
)
],
},
prefer_skip_nested_validation=False, # estimator is not validated yet
)
def cross_val_predict(
estimator,
X,
Expand Down Expand Up @@ -912,10 +936,11 @@ def cross_val_predict(
Parameters
----------
estimator : estimator object implementing 'fit' and 'predict'
The object to use to fit the data.
estimator : estimator
The estimator instance to use to fit the data. It must implement a `fit`
method and the method given by the `method` parameter.
X : array-like of shape (n_samples, n_features)
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The data to fit. Can be, for example a list, or an array at least 2d.
y : array-like of shape (n_samples,) or (n_samples, n_outputs), \
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def _check_function_param_validation(
"sklearn.metrics.top_k_accuracy_score",
"sklearn.metrics.v_measure_score",
"sklearn.metrics.zero_one_loss",
"sklearn.model_selection.cross_val_predict",
"sklearn.model_selection.cross_val_score",
"sklearn.model_selection.cross_validate",
"sklearn.model_selection.learning_curve",
Expand Down

0 comments on commit 7f871fe

Please sign in to comment.