Skip to content

Commit

Permalink
FIX DummyRegressor overriding constant (scikit-learn#22486)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan committed Feb 15, 2022
1 parent 5d1571d commit b28c5bb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ Changelog
:class:`discriminant_analysis.LinearDiscriminantAnalysis`. :pr:`22120` by
`Thomas Fan`_.

:mod:`sklearn.dummy`
....................

- |Fix| :class:`dummy.DummyRegressor` no longer overrides the `constant`
parameter during `fit`. :pr:`22486` by `Thomas Fan`_.

:mod:`sklearn.ensemble`
.......................

Expand Down
6 changes: 2 additions & 4 deletions sklearn/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,20 +606,18 @@ def fit(self, X, y, sample_weight=None):
"when the constant strategy is used."
)

self.constant = check_array(
self.constant_ = check_array(
self.constant,
accept_sparse=["csr", "csc", "coo"],
ensure_2d=False,
ensure_min_samples=0,
)

if self.n_outputs_ != 1 and self.constant.shape[0] != y.shape[1]:
if self.n_outputs_ != 1 and self.constant_.shape[0] != y.shape[1]:
raise ValueError(
"Constant target value should have shape (%d, 1)." % y.shape[1]
)

self.constant_ = self.constant

self.constant_ = np.reshape(self.constant_, (1, -1))
return self

Expand Down
3 changes: 3 additions & 0 deletions sklearn/tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,9 @@ def test_constant_strategy_regressor():
reg.fit(X, y)
assert_array_equal(reg.predict(X), [43] * len(X))

# non-regression test for #22478
assert not isinstance(reg.constant, np.ndarray)


def test_constant_strategy_multioutput_regressor():

Expand Down

0 comments on commit b28c5bb

Please sign in to comment.