[MRG+1] Added unit test for adding classes_ property to GridSearchCV, fixes scikit-learn#6298 (scikit-learn#7661)

* Fix issue scikit-learn#6298
Adds a "classes_" property to BaseSearchCV

* Added test to ensure classes_ property is added to gridSearch correctly

* Fixed formatting

* Added test to ensure gridSearchCV with a regressor does not have a classes_ attribute

* Fixed whitespace issues

* Combined tests for the new GridSearchCV.classes_ property into a single test.

* Removed trailing whitespace

* Added what's new for pull request scikit-learn#7661

* Fixed formatting of update in what's new
abatula authored and amueller committed Oct 20, 2016
1 parent 30b936d commit 707b6f9
Showing 3 changed files with 29 additions and 2 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -19,6 +19,11 @@ New features
Enhancements
............

- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
  that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661
  <https://github.com/scikit-learn/scikit-learn/pull/7661>`_) by `Alyssa
  Batula`_ and `Dylan Werner-Meier`_.

- The ``min_weight_fraction_leaf`` constraint in tree construction is now
  more efficient, taking a fast path to declare a node a leaf if its weight
  is less than 2 * the minimum. Note that the constructed tree will be
8 changes: 6 additions & 2 deletions sklearn/grid_search.py
@@ -387,6 +387,10 @@ def __init__(self, estimator, scoring=None,
    def _estimator_type(self):
        return self.estimator._estimator_type

    @property
    def classes_(self):
        return self.best_estimator_.classes_

    def score(self, X, y=None):
        """Returns the score on the given data, if the estimator has been refit.
@@ -688,7 +692,7 @@ class GridSearchCV(BaseSearchCV):
- An iterable yielding train/test splits.
For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
other cases, :class:`sklearn.model_selection.KFold` is used.
@@ -900,7 +904,7 @@ class RandomizedSearchCV(BaseSearchCV):
- An iterable yielding train/test splits.
For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
other cases, :class:`sklearn.model_selection.KFold` is used.
18 changes: 18 additions & 0 deletions sklearn/tests/test_grid_search.py
@@ -42,6 +42,7 @@
from sklearn.metrics import f1_score
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import Ridge

from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.exceptions import FitFailedWarning
@@ -785,3 +786,20 @@ def test_parameters_sampler_replacement():
    sampler = ParameterSampler(params_distribution, n_iter=7)
    samples = list(sampler)
    assert_equal(len(samples), 7)


def test_classes__property():
    # Test that classes_ property matches best_estimator_.classes_
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    Cs = [.1, 1, 10]

    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
    grid_search.fit(X, y)
    assert_array_equal(grid_search.best_estimator_.classes_,
                       grid_search.classes_)

    # Test that regressors do not have a classes_ attribute
    grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
    grid_search.fit(X, y)
    assert_false(hasattr(grid_search, 'classes_'))
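
Note on the regressor assertion above: it passes because the new property simply delegates to ``best_estimator_.classes_``. When the refit estimator has no such attribute, the lookup raises ``AttributeError`` inside the property, and Python's ``hasattr`` reports ``False``. The sketch below illustrates that mechanism without scikit-learn. ``FakeClassifier``, ``FakeRegressor`` and ``FakeSearch`` are hypothetical stand-ins invented for this illustration, not classes from the library.

```python
class FakeClassifier:
    """Hypothetical stand-in for a fitted classifier."""

    def __init__(self):
        self.classes_ = [0, 1]


class FakeRegressor:
    """Hypothetical stand-in for a fitted regressor (defines no classes_)."""


class FakeSearch:
    """Hypothetical stand-in for a fitted search object with a refit estimator."""

    def __init__(self, best_estimator):
        self.best_estimator_ = best_estimator

    @property
    def classes_(self):
        # Delegate to the refit estimator. If it has no classes_,
        # this attribute access raises AttributeError.
        return self.best_estimator_.classes_


clf_search = FakeSearch(FakeClassifier())
reg_search = FakeSearch(FakeRegressor())

# The property mirrors the wrapped classifier's classes_.
print(clf_search.classes_)

# hasattr() turns the AttributeError raised inside the property into False,
# so the search object only appears to have classes_ for classifiers.
print(hasattr(clf_search, "classes_"))
print(hasattr(reg_search, "classes_"))
```

One consequence of this design is that accessing ``classes_`` before fitting also fails with ``AttributeError``, since ``best_estimator_`` does not exist yet.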
