Skip to content

Commit

Permalink
[MRG+2] adding multilabel support for score_func (scikit-learn#7676)
Browse files Browse the repository at this point in the history
* added multilabel support for score function

* added test for multilabel score function

* updated whats_new.rst

* updated whats_new.rst with working link
  • Loading branch information
affanv14 authored and amueller committed Oct 20, 2016
1 parent 568c002 commit ee3e617
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
6 changes: 5 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ Bug fixes
`#6497 <https://github.com/scikit-learn/scikit-learn/pull/6497>`_
by `Sebastian Säger`_


- Fixes issue in :ref:`univariate_feature_selection` where score
functions were not accepting multi-label targets.(`#7676
<https://github.com/scikit-learn/scikit-learn/pull/7676>`_)
by `Mohammed Affan`_

.. _changes_0_18:

Version 0.18
Expand Down
15 changes: 15 additions & 0 deletions sklearn/feature_selection/tests/test_feature_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,21 @@ def test_tied_pvalues():
assert_not_in(9998, Xt)


def test_scorefunc_multilabel():
# Test whether k-best and percentiles works with multilabels with chi2.

X = np.array([[10000, 9999, 0], [100, 9999, 0], [1000, 99, 0]])
y = [[1, 1], [0, 1], [1, 0]]

Xt = SelectKBest(chi2, k=2).fit_transform(X, y)
assert_equal(Xt.shape, (3, 2))
assert_not_in(0, Xt)

Xt = SelectPercentile(chi2, percentile=67).fit_transform(X, y)
assert_equal(Xt.shape, (3, 2))
assert_not_in(0, Xt)


def test_tied_scores():
# Test for stable sorting in k-best with tied scores.
X_train = np.array([[0, 0, 0], [1, 1, 1]])
Expand Down
2 changes: 1 addition & 1 deletion sklearn/feature_selection/univariate_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def fit(self, X, y):
self : object
Returns self.
"""
X, y = check_X_y(X, y, ['csr', 'csc'])
X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True)

if not callable(self.score_func):
raise TypeError("The score function should be a callable, %s (%s) "
Expand Down

0 comments on commit ee3e617

Please sign in to comment.