Skip to content

Commit

Permalink
[MRG + 1] FIX raise an error message when n_groups > number of groups (
Browse files Browse the repository at this point in the history
…scikit-learn#7681) (scikit-learn#7683)

* FIX raise an error message when n_groups > actual number of groups (scikit-learn#7681)

This change addresses issue scikit-learn#7681:
- Raise ValueError when n_groups > actual number of unique groups in LeaveOneGroupOut and LeavePGroupsOut.
- Add unit test.

* Make requested changes

- Check error message with `assert_raise_message`
- Pass parameters to `assert_raise_message` instead of defining functions

* Update condition and exception message
  • Loading branch information
polmauri authored and amueller committed Oct 25, 2016
1 parent ff5c36e commit 73d3f03
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
10 changes: 10 additions & 0 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,10 @@ def _iter_test_masks(self, X, y, groups):
# We make a copy of groups to avoid side-effects during iteration
groups = np.array(groups, copy=True)
unique_groups = np.unique(groups)
if len(unique_groups) <= 1:
raise ValueError(
"The groups parameter contains fewer than 2 unique groups "
"(%s). LeaveOneGroupOut expects at least 2." % unique_groups)
for i in unique_groups:
yield groups == i

Expand Down Expand Up @@ -862,6 +866,12 @@ def _iter_test_masks(self, X, y, groups):
raise ValueError("The groups parameter should not be None")
groups = np.array(groups, copy=True)
unique_groups = np.unique(groups)
if self.n_groups >= len(unique_groups):
raise ValueError(
"The groups parameter contains fewer than (or equal to) "
"n_groups (%d) numbers of unique groups (%s). LeavePGroupsOut "
"expects that at least n_groups + 1 (%d) unique groups be "
"present" % (self.n_groups, unique_groups, self.n_groups + 1))
combi = combinations(range(len(unique_groups)), self.n_groups)
for indices in combi:
test_index = np.zeros(_num_samples(X), dtype=np.bool)
Expand Down
25 changes: 25 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,31 @@ def test_leave_group_out_changing_groups():
assert_equal(3, LeaveOneGroupOut().get_n_splits(X, y, groups))


def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
X = y = groups = np.ones(0)
msg = ("The groups parameter contains fewer than 2 unique groups ([]). "
"LeaveOneGroupOut expects at least 2.")
assert_raise_message(ValueError, msg, next,
LeaveOneGroupOut().split(X, y, groups))
X = y = groups = np.ones(1)
msg = ("The groups parameter contains fewer than 2 unique groups ([ 1.]). "
"LeaveOneGroupOut expects at least 2.")
assert_raise_message(ValueError, msg, next,
LeaveOneGroupOut().split(X, y, groups))
X = y = groups = np.ones(1)
msg = ("The groups parameter contains fewer than (or equal to) n_groups "
"(3) numbers of unique groups ([ 1.]). LeavePGroupsOut expects "
"that at least n_groups + 1 (4) unique groups be present")
assert_raise_message(ValueError, msg, next,
LeavePGroupsOut(n_groups=3).split(X, y, groups))
X = y = groups = np.arange(3)
msg = ("The groups parameter contains fewer than (or equal to) n_groups "
"(3) numbers of unique groups ([0 1 2]). LeavePGroupsOut expects "
"that at least n_groups + 1 (4) unique groups be present")
assert_raise_message(ValueError, msg, next,
LeavePGroupsOut(n_groups=3).split(X, y, groups))


def test_train_test_split_errors():
assert_raises(ValueError, train_test_split)
assert_raises(ValueError, train_test_split, range(3), train_size=1.1)
Expand Down

0 comments on commit 73d3f03

Please sign in to comment.