Skip to content

Commit

Permalink
FEA add ValidationCurveDisplay in model_selection module (scikit-lear…
Browse files Browse the repository at this point in the history
…n#25120)

Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
3 people committed Jun 14, 2023
1 parent 15f7cfb commit 8cac52f
Show file tree
Hide file tree
Showing 11 changed files with 999 additions and 234 deletions.
1 change: 1 addition & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,7 @@ Visualization
:template: display_only_from_estimator.rst

model_selection.LearningCurveDisplay
model_selection.ValidationCurveDisplay

.. _multiclass_ref:

Expand Down
42 changes: 31 additions & 11 deletions doc/modules/learning_curve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The function :func:`validation_curve` can help in this case::
>>> import numpy as np
>>> from sklearn.model_selection import validation_curve
>>> from sklearn.datasets import load_iris
>>> from sklearn.linear_model import Ridge
>>> from sklearn.svm import SVC

>>> np.random.seed(0)
>>> X, y = load_iris(return_X_y=True)
Expand All @@ -80,30 +80,50 @@ The function :func:`validation_curve` can help in this case::
>>> X, y = X[indices], y[indices]

>>> train_scores, valid_scores = validation_curve(
... Ridge(), X, y, param_name="alpha", param_range=np.logspace(-7, 3, 3),
... cv=5)
... SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 3),
... )
>>> train_scores
array([[0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
[0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
[0.51..., 0.52..., 0.49..., 0.47..., 0.49...]])
array([[0.90..., 0.94..., 0.91..., 0.89..., 0.92...],
[0.9... , 0.92..., 0.93..., 0.92..., 0.93...],
[0.97..., 1... , 0.98..., 0.97..., 0.99...]])
>>> valid_scores
array([[0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
[0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
[0.46..., 0.25..., 0.50..., 0.49..., 0.52...]])
array([[0.9..., 0.9... , 0.9... , 0.96..., 0.9... ],
[0.9..., 0.83..., 0.96..., 0.96..., 0.93...],
[1.... , 0.93..., 1.... , 1.... , 0.9... ]])

If you intend to plot the validation curves only, the class
:class:`~sklearn.model_selection.ValidationCurveDisplay` is more direct than
using matplotlib manually on the results of a call to :func:`validation_curve`.
You can use the method
:meth:`~sklearn.model_selection.ValidationCurveDisplay.from_estimator` similarly
to :func:`validation_curve` to generate and plot the validation curve:

.. plot::
:context: close-figs
:align: center

from sklearn.datasets import load_iris
from sklearn.model_selection import ValidationCurveDisplay
from sklearn.svm import SVC
from sklearn.utils import shuffle
X, y = load_iris(return_X_y=True)
X, y = shuffle(X, y, random_state=0)
ValidationCurveDisplay.from_estimator(
SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 10)
)

If the training score and the validation score are both low, the estimator will
be underfitting. If the training score is high and the validation score is low,
the estimator is overfitting and otherwise it is working very well. A low
training score and a high validation score is usually not possible. Underfitting,
overfitting, and a working model are shown in the in the plot below where we vary
the parameter :math:`\gamma` of an SVM on the digits dataset.
the parameter `gamma` of an SVM with an RBF kernel on the digits dataset.

.. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_validation_curve_001.png
:target: ../auto_examples/model_selection/plot_validation_curve.html
:align: center
:scale: 50%


.. _learning_curve:

Learning curve
Expand Down
1 change: 1 addition & 0 deletions doc/visualizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ Display Objects
metrics.PredictionErrorDisplay
metrics.RocCurveDisplay
model_selection.LearningCurveDisplay
model_selection.ValidationCurveDisplay
20 changes: 20 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ random sampling procedures.
used each time the kernel is called.
:pr:`26337` by :user:`Yao Xiao <Charlie-XIAO>`.

Changed displays
----------------

- |Enhancement| :class:`model_selection.LearningCurveDisplay` displays both the
train and test curves by default. You can set `score_type="test"` to keep the
past behaviour.
:pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.

Changes impacting all modules
-----------------------------

Expand Down Expand Up @@ -548,6 +556,18 @@ Changelog
:mod:`sklearn.model_selection`
..............................

- |MajorFeature| Added the class :class:`model_selection.ValidationCurveDisplay`
that allows easy plotting of validation curves obtained by the function
:func:`model_selection.validation_curve`.
:pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.

- |API| The parameter `log_scale` in the class
:class:`model_selection.LearningCurveDisplay` has been deprecated in 1.3 and
will be removed in 1.5. The default scale can be overriden by setting it
directly on the `ax` object and will be set automatically from the spacing
of the data points otherwise.
:pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| :func:`model_selection.cross_validate` accepts a new parameter
`return_indices` to return the train-test indices of each cv split.
:pr:`25659` by :user:`Guillaume Lemaitre <glemaitre>`.
Expand Down
1 change: 1 addition & 0 deletions examples/miscellaneous/plot_kernel_ridge_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
"scoring": "neg_mean_squared_error",
"negate_score": True,
"score_name": "Mean Squared Error",
"score_type": "test",
"std_display_style": None,
"ax": ax,
}
Expand Down
46 changes: 8 additions & 38 deletions examples/model_selection/plot_validation_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,23 @@

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve
from sklearn.model_selection import ValidationCurveDisplay

X, y = load_digits(return_X_y=True)
subset_mask = np.isin(y, [1, 2]) # binary classification: 1 vs 2
X, y = X[subset_mask], y[subset_mask]

param_range = np.logspace(-6, -1, 5)
train_scores, test_scores = validation_curve(
disp = ValidationCurveDisplay.from_estimator(
SVC(),
X,
y,
param_name="gamma",
param_range=param_range,
scoring="accuracy",
param_range=np.logspace(-6, -1, 5),
score_type="both",
n_jobs=2,
score_name="Accuracy",
)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plt.title("Validation Curve with SVM")
plt.xlabel(r"$\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
lw = 2
plt.semilogx(
param_range, train_scores_mean, label="Training score", color="darkorange", lw=lw
)
plt.fill_between(
param_range,
train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std,
alpha=0.2,
color="darkorange",
lw=lw,
)
plt.semilogx(
param_range, test_scores_mean, label="Cross-validation score", color="navy", lw=lw
)
plt.fill_between(
param_range,
test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std,
alpha=0.2,
color="navy",
lw=lw,
)
plt.legend(loc="best")
disp.ax_.set_title("Validation Curve for SVM with an RBF kernel")
disp.ax_.set_xlabel(r"gamma (inverse radius of the RBF kernel)")
disp.ax_.set_ylim(0.0, 1.1)
plt.show()
2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._search import ParameterSampler

from ._plot import LearningCurveDisplay
from ._plot import ValidationCurveDisplay

if typing.TYPE_CHECKING:
# Avoid errors in type checkers (e.g. mypy) for experimental estimators.
Expand Down Expand Up @@ -74,6 +75,7 @@
"permutation_test_score",
"train_test_split",
"validation_curve",
"ValidationCurveDisplay",
]


Expand Down
Loading

0 comments on commit 8cac52f

Please sign in to comment.