FIX PowerTransformer raise when "box-cox" has nan column (scikit-learn#26400)

Co-authored-by: Olivier Grisel <[email protected]>
Charlie-XIAO and ogrisel committed Jun 13, 2023
1 parent d29f78e commit 18cf8d0
Showing 3 changed files with 24 additions and 1 deletion.
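
The guard this commit adds to `_box_cox_optimize` can be sketched in isolation. The snippet below is a minimal illustration, not scikit-learn code: `non_nan_values` is a hypothetical helper name, and only the mask, the `np.all` check, and the error message are taken from the diff. It shows that an all-NaN column leaves nothing to fit once the NaNs are masked out, which is why the column is rejected up front.

```python
import numpy as np


def non_nan_values(x):
    """Return the non-NaN entries of a 1-D column, or raise if none remain.

    Hypothetical helper: it mirrors the guard added to ``_box_cox_optimize``
    but is not part of scikit-learn.
    """
    mask = np.isnan(x)
    if np.all(mask):
        # Same message as the one introduced by the commit.
        raise ValueError("Column must not be all nan.")
    return x[~mask]


# A column with some missing values keeps only its observed entries.
partial = np.array([1.0, np.nan, 3.0])
kept = non_nan_values(partial)

# An all-NaN column is rejected before any lambda estimation is attempted.
all_nan = np.full(4, np.nan)
try:
    non_nan_values(all_nan)
    raised = False
except ValueError as exc:
    raised = True
    message = str(exc)
```

In the patched method, the surviving values are then passed to `stats.boxcox(..., lmbda=None)` to estimate lambda; that step is left out here so the sketch depends only on NumPy.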
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -628,6 +628,10 @@ Changelog
   The `sample_interval_` attribute is deprecated and will be removed in 1.5.
   :pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

+- |Fix| :class:`preprocessing.PowerTransformer` now correctly raises an error when
+  using `method="box-cox"` on data with a constant `np.nan` column.
+  :pr:`26400` by :user:`Yao Xiao <Charlie-XIAO>`.
+
 :mod:`sklearn.svm`
 ..................

6 changes: 5 additions & 1 deletion sklearn/preprocessing/_data.py
@@ -3311,9 +3311,13 @@ def _box_cox_optimize(self, x):
         We here use scipy builtins which use the brent optimizer.
         """
+        mask = np.isnan(x)
+        if np.all(mask):
+            raise ValueError("Column must not be all nan.")
+
         # the computation of lambda is influenced by NaNs so we need to
         # get rid of them
-        _, lmbda = stats.boxcox(x[~np.isnan(x)], lmbda=None)
+        _, lmbda = stats.boxcox(x[~mask], lmbda=None)

         return lmbda

15 changes: 15 additions & 0 deletions sklearn/preprocessing/tests/test_data.py
@@ -2527,6 +2527,21 @@ def test_power_transformer_copy_False(method, standardize):
     assert X_trans is X_inv_trans


+def test_power_transformer_box_cox_raise_all_nans_col():
+    """Check that box-cox raises an informative error when a column is all nans.
+
+    Non-regression test for gh-26303
+    """
+    X = rng.random_sample((4, 5))
+    X[:, 0] = np.nan
+
+    err_msg = "Column must not be all nan."
+
+    pt = PowerTransformer(method="box-cox")
+    with pytest.raises(ValueError, match=err_msg):
+        pt.fit_transform(X)
+
+
 @pytest.mark.parametrize(
     "X_2",
     [