diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index 11add6151f91b..68cf826c43fec 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -20,15 +20,19 @@ Version 1.5.1

 **TODO**

-Changelog
----------
-
 Changes impacting many modules
 ------------------------------

+- |Fix| Fixed a regression in the validation of the input data of all estimators where
+  an unexpected error was raised when passing a DataFrame backed by a read-only buffer.
+  :pr:`29018` by :user:`Jérémie du Boisberranger`.
+
 - |Fix| Fixed a regression causing a dead-lock at import time in some settings.
   :pr:`29235` by :user:`Jérémie du Boisberranger`.

+Changelog
+---------
+
 :mod:`sklearn.metrics`
 ......................

@@ -37,6 +41,10 @@ Changes impacting many modules
   instead of implicitly converting those inputs as regular NumPy arrays.
   :pr:`29119` by :user:`Olivier Grisel`.

+- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
+  for Array API dispatch with multilabel inputs.
+  :pr:`29269` by :user:`Yaroslav Korobko`.
+
 :mod:`sklearn.model_selection`
 ..............................

@@ -48,12 +56,14 @@ Changes impacting many modules
   grids that have estimators as parameter values.
   :pr:`29179` by :user:`Marco Gorelli`.

-:mod:`sklearn.metrics`
-..............................
+:mod:`sklearn.utils`
+....................

-- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
-  for Array API dispatch with multilabel inputs.
-  :pr:`29269` by :user:`Yaroslav Korobko`.
+- |API| :func:`utils.validation.check_array` has a new parameter, `force_writeable`, to
+  control the writeability of the output array. If set to `True`, the output array will
+  be guaranteed to be writeable and a copy will be made if the input array is read-only.
+  If set to `False`, no guarantee is made about the writeability of the output array.
+  :pr:`29018` by :user:`Jérémie du Boisberranger`.

 .. _changes_1_5:

diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py
index f68fa0522f6ff..fa5a3513ed899 100644
--- a/sklearn/cluster/_affinity_propagation.py
+++ b/sklearn/cluster/_affinity_propagation.py
@@ -502,13 +502,10 @@ def fit(self, X, y=None):
             Returns the instance itself.
         """
         if self.affinity == "precomputed":
-            accept_sparse = False
-        else:
-            accept_sparse = "csr"
-        X = self._validate_data(X, accept_sparse=accept_sparse)
-        if self.affinity == "precomputed":
-            self.affinity_matrix_ = X.copy() if self.copy else X
+            X = self._validate_data(X, copy=self.copy, force_writeable=True)
+            self.affinity_matrix_ = X
         else:  # self.affinity == "euclidean"
+            X = self._validate_data(X, accept_sparse="csr")
             self.affinity_matrix_ = -euclidean_distances(X, squared=True)

         if self.affinity_matrix_.shape[0] != self.affinity_matrix_.shape[1]:
diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py
index 9933318313cc8..d20e745309fca 100644
--- a/sklearn/cluster/_hdbscan/hdbscan.py
+++ b/sklearn/cluster/_hdbscan/hdbscan.py
@@ -770,6 +770,7 @@ def fit(self, X, y=None):
                 X,
                 accept_sparse=["csr", "lil"],
                 dtype=np.float64,
+                force_writeable=True,
             )
         else:
             # Only non-sparse, precomputed distance matrices are handled here
@@ -777,7 +778,9 @@ def fit(self, X, y=None):
             # Perform data validation after removing infinite values (numpy.inf)
             # from the given distance matrix.
-            X = self._validate_data(X, force_all_finite=False, dtype=np.float64)
+            X = self._validate_data(
+                X, force_all_finite=False, dtype=np.float64, force_writeable=True
+            )
             if np.isnan(X).any():
                 # TODO: Support np.nan in Cython implementation for precomputed
                 # dense HDBSCAN
diff --git a/sklearn/cross_decomposition/_pls.py b/sklearn/cross_decomposition/_pls.py
index 143149b1bb4db..16024cf961d27 100644
--- a/sklearn/cross_decomposition/_pls.py
+++ b/sklearn/cross_decomposition/_pls.py
@@ -263,10 +263,19 @@ def fit(self, X, y=None, Y=None):
         check_consistent_length(X, y)
         X = self._validate_data(
-            X, dtype=np.float64, copy=self.copy, ensure_min_samples=2
+            X,
+            dtype=np.float64,
+            force_writeable=True,
+            copy=self.copy,
+            ensure_min_samples=2,
         )
         y = check_array(
-            y, input_name="y", dtype=np.float64, copy=self.copy, ensure_2d=False
+            y,
+            input_name="y",
+            dtype=np.float64,
+            force_writeable=True,
+            copy=self.copy,
+            ensure_2d=False,
         )

         if y.ndim == 1:
             self._predict_1d = True
@@ -1056,10 +1065,19 @@ def fit(self, X, y=None, Y=None):
         y = _deprecate_Y_when_required(y, Y)
         check_consistent_length(X, y)
         X = self._validate_data(
-            X, dtype=np.float64, copy=self.copy, ensure_min_samples=2
+            X,
+            dtype=np.float64,
+            force_writeable=True,
+            copy=self.copy,
+            ensure_min_samples=2,
         )
         y = check_array(
-            y, input_name="y", dtype=np.float64, copy=self.copy, ensure_2d=False
+            y,
+            input_name="y",
+            dtype=np.float64,
+            force_writeable=True,
+            copy=self.copy,
+            ensure_2d=False,
         )
         if y.ndim == 1:
             y = y.reshape(-1, 1)
diff --git a/sklearn/decomposition/_factor_analysis.py b/sklearn/decomposition/_factor_analysis.py
index 2164feb67aa26..df45606fe3de4 100644
--- a/sklearn/decomposition/_factor_analysis.py
+++ b/sklearn/decomposition/_factor_analysis.py
@@ -216,7 +216,9 @@ def fit(self, X, y=None):
         self : object
             FactorAnalysis class instance.
         """
-        X = self._validate_data(X, copy=self.copy, dtype=np.float64)
+        X = self._validate_data(
+            X, copy=self.copy, dtype=np.float64, force_writeable=True
+        )

         n_samples, n_features = X.shape
         n_components = self.n_components
diff --git a/sklearn/decomposition/_incremental_pca.py b/sklearn/decomposition/_incremental_pca.py
index fea5b952a262a..8b345a797e452 100644
--- a/sklearn/decomposition/_incremental_pca.py
+++ b/sklearn/decomposition/_incremental_pca.py
@@ -228,6 +228,7 @@ def fit(self, X, y=None):
             accept_sparse=["csr", "csc", "lil"],
             copy=self.copy,
             dtype=[np.float64, np.float32],
+            force_writeable=True,
         )
         n_samples, n_features = X.shape

@@ -277,7 +278,11 @@ def partial_fit(self, X, y=None, check_input=True):
                 "or use IncrementalPCA.fit to do so in batches."
                 )
             X = self._validate_data(
-                X, copy=self.copy, dtype=[np.float64, np.float32], reset=first_pass
+                X,
+                copy=self.copy,
+                dtype=[np.float64, np.float32],
+                force_writeable=True,
+                reset=first_pass,
             )
         n_samples, n_features = X.shape
         if first_pass:
diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py
index 51d01a5781720..ffbf42d32b2bc 100644
--- a/sklearn/decomposition/_pca.py
+++ b/sklearn/decomposition/_pca.py
@@ -505,6 +505,7 @@ def _fit(self, X):
         X = self._validate_data(
             X,
             dtype=[xp.float64, xp.float32],
+            force_writeable=True,
             accept_sparse=("csr", "csc"),
             ensure_2d=True,
             copy=False,
diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py
index 5a053f6b3ddfe..6109e3fde7b2a 100644
--- a/sklearn/impute/_base.py
+++ b/sklearn/impute/_base.py
@@ -333,6 +333,7 @@ def _validate_input(self, X, in_fit):
                 reset=in_fit,
                 accept_sparse="csc",
                 dtype=dtype,
+                force_writeable=True if not in_fit else None,
                 force_all_finite=force_all_finite,
                 copy=self.copy,
             )
diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py
index 5ac7216fc8188..2e18246b4b9bb 100644
--- a/sklearn/impute/_knn.py
+++ b/sklearn/impute/_knn.py
@@ -269,6 +269,7 @@ def transform(self, X):
             X,
             accept_sparse=False,
             dtype=FLOAT_DTYPES,
+            force_writeable=True,
             force_all_finite=force_all_finite,
             copy=self.copy,
             reset=False,
diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py
index 20de9a61fe788..0ca59d97948bc 100644
--- a/sklearn/linear_model/_base.py
+++ b/sklearn/linear_model/_base.py
@@ -598,7 +598,12 @@ def fit(self, X, y, sample_weight=None):
         accept_sparse = False if self.positive else ["csr", "csc", "coo"]

         X, y = self._validate_data(
-            X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True
+            X,
+            y,
+            accept_sparse=accept_sparse,
+            y_numeric=True,
+            multi_output=True,
+            force_writeable=True,
         )

         has_sw = sample_weight is not None
diff --git a/sklearn/linear_model/_bayes.py b/sklearn/linear_model/_bayes.py
index e87ea5320c6c1..dfdcdf23599c9 100644
--- a/sklearn/linear_model/_bayes.py
+++ b/sklearn/linear_model/_bayes.py
@@ -235,7 +235,9 @@ def fit(self, X, y, sample_weight=None):
         self : object
             Returns the instance itself.
         """
-        X, y = self._validate_data(X, y, dtype=[np.float64, np.float32], y_numeric=True)
+        X, y = self._validate_data(
+            X, y, dtype=[np.float64, np.float32], force_writeable=True, y_numeric=True
+        )
         dtype = X.dtype

         if sample_weight is not None:
@@ -620,7 +622,12 @@ def fit(self, X, y):
             Fitted estimator.
""" X, y = self._validate_data( - X, y, dtype=[np.float64, np.float32], y_numeric=True, ensure_min_samples=2 + X, + y, + dtype=[np.float64, np.float32], + force_writeable=True, + y_numeric=True, + ensure_min_samples=2, ) dtype = X.dtype diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 1b29db27f16f8..c23527de9e07b 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -979,6 +979,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): accept_sparse="csc", order="F", dtype=[np.float64, np.float32], + force_writeable=True, accept_large_sparse=False, copy=X_copied, multi_output=True, @@ -1607,6 +1608,7 @@ def fit(self, X, y, sample_weight=None, **params): check_X_params = dict( accept_sparse="csc", dtype=[np.float64, np.float32], + force_writeable=True, copy=False, accept_large_sparse=False, ) @@ -1632,6 +1634,7 @@ def fit(self, X, y, sample_weight=None, **params): accept_sparse="csc", dtype=[np.float64, np.float32], order="F", + force_writeable=True, copy=copy_X, ) X, y = self._validate_data( @@ -2508,6 +2511,7 @@ def fit(self, X, y): check_X_params = dict( dtype=[np.float64, np.float32], order="F", + force_writeable=True, copy=self.copy_X and self.fit_intercept, ) check_y_params = dict(ensure_2d=False, order="F") diff --git a/sklearn/linear_model/_least_angle.py b/sklearn/linear_model/_least_angle.py index 3090f6d147ad6..378010c7cdb58 100644 --- a/sklearn/linear_model/_least_angle.py +++ b/sklearn/linear_model/_least_angle.py @@ -1177,7 +1177,9 @@ def fit(self, X, y, Xy=None): self : object Returns an instance of self. """ - X, y = self._validate_data(X, y, y_numeric=True, multi_output=True) + X, y = self._validate_data( + X, y, force_writeable=True, y_numeric=True, multi_output=True + ) alpha = getattr(self, "alpha", 0.0) if hasattr(self, "n_nonzero_coefs"): @@ -1718,7 +1720,7 @@ def fit(self, X, y, **params): """ _raise_for_params(params, self, "fit") - X, y = self._validate_data(X, y, y_numeric=True) + X, y = self._validate_data(X, y, force_writeable=True, y_numeric=True) X = as_float_array(X, copy=self.copy_X) y = as_float_array(y, copy=self.copy_X) @@ -2235,7 +2237,7 @@ def fit(self, X, y, copy_X=None): """ if copy_X is None: copy_X = self.copy_X - X, y = self._validate_data(X, y, y_numeric=True) + X, y = self._validate_data(X, y, force_writeable=True, y_numeric=True) X, y, Xmean, ymean, Xstd = _preprocess_data( X, y, fit_intercept=self.fit_intercept, copy=copy_X diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index ac8e094a88ad0..c9143389739af 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1241,6 +1241,7 @@ def fit(self, X, y, sample_weight=None): y, accept_sparse=_accept_sparse, dtype=[xp.float64, xp.float32], + force_writeable=True, multi_output=True, y_numeric=True, ) @@ -1290,6 +1291,7 @@ def _prepare_data(self, X, y, sample_weight, solver): accept_sparse=accept_sparse, multi_output=True, y_numeric=False, + force_writeable=True, ) self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index d1415e0ff71d2..7e7d8a8dd3c17 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -529,6 +529,7 @@ def transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), + force_writeable=True, force_all_finite="allow-nan", reset=False, ) @@ -560,6 +561,7 @@ def 
inverse_transform(self, X): X, copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), + force_writeable=True, force_all_finite="allow-nan", ) @@ -1040,6 +1042,7 @@ def transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1081,6 +1084,7 @@ def inverse_transform(self, X, copy=None): accept_sparse="csr", copy=copy, dtype=FLOAT_DTYPES, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1285,6 +1289,7 @@ def transform(self, X): copy=self.copy, reset=False, dtype=_array_api.supported_float_dtypes(xp), + force_writeable=True, force_all_finite="allow-nan", ) @@ -1316,6 +1321,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=_array_api.supported_float_dtypes(xp), + force_writeable=True, force_all_finite="allow-nan", ) @@ -1648,6 +1654,7 @@ def transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, + force_writeable=True, reset=False, force_all_finite="allow-nan", ) @@ -1681,6 +1688,7 @@ def inverse_transform(self, X): accept_sparse=("csr", "csc"), copy=self.copy, dtype=FLOAT_DTYPES, + force_writeable=True, force_all_finite="allow-nan", ) @@ -1922,6 +1930,7 @@ def normalize(X, norm="l2", *, axis=1, copy=True, return_norm=False): copy=copy, estimator="the normalize function", dtype=_array_api.supported_float_dtypes(xp), + force_writeable=True, ) if axis == 0: X = X.T @@ -2085,8 +2094,10 @@ def transform(self, X, copy=None): Transformed array. """ copy = copy if copy is not None else self.copy - X = self._validate_data(X, accept_sparse="csr", reset=False) - return normalize(X, norm=self.norm, axis=1, copy=copy) + X = self._validate_data( + X, accept_sparse="csr", force_writeable=True, copy=copy, reset=False + ) + return normalize(X, norm=self.norm, axis=1, copy=False) def _more_tags(self): return {"stateless": True, "array_api_support": True} @@ -2140,7 +2151,7 @@ def binarize(X, *, threshold=0.0, copy=True): array([[0., 1., 0.], [1., 0., 0.]]) """ - X = check_array(X, accept_sparse=["csr", "csc"], copy=copy) + X = check_array(X, accept_sparse=["csr", "csc"], force_writeable=True, copy=copy) if sparse.issparse(X): if threshold < 0: raise ValueError("Cannot binarize a sparse matrix with threshold < 0") @@ -2281,7 +2292,13 @@ def transform(self, X, copy=None): copy = copy if copy is not None else self.copy # TODO: This should be refactored because binarize also calls # check_array - X = self._validate_data(X, accept_sparse=["csr", "csc"], copy=copy, reset=False) + X = self._validate_data( + X, + accept_sparse=["csr", "csc"], + force_writeable=True, + copy=copy, + reset=False, + ) return binarize(X, threshold=self.threshold, copy=False) def _more_tags(self): @@ -2842,6 +2859,9 @@ def _check_inputs(self, X, in_fit, accept_sparse_negative=False, copy=False): accept_sparse="csc", copy=copy, dtype=FLOAT_DTYPES, + # only set force_writeable for the validation at transform time because + # it's the only place where QuantileTransformer performs inplace operations. 
+            force_writeable=True if not in_fit else None,
             force_all_finite="allow-nan",
         )
         # we only accept positive sparse matrix when ignore_implicit_zeros is
@@ -3480,6 +3500,7 @@ def _check_input(self, X, in_fit, check_positive=False, check_shape=False):
             X,
             ensure_2d=True,
             dtype=FLOAT_DTYPES,
+            force_writeable=True,
             copy=self.copy,
             force_all_finite="allow-nan",
             reset=in_fit,
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index 4f5cf3ae3e62c..47af38a563a77 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -79,6 +79,7 @@
     check_get_feature_names_out_error,
     check_global_output_transform_pandas,
     check_global_set_output_transform_polars,
+    check_inplace_ensure_writeable,
     check_n_features_in_after_fitting,
     check_param_validation,
     check_set_output_transform,
@@ -617,3 +618,31 @@ def test_set_output_transform_configured(estimator, check_func):
     _set_checking_parameters(estimator)
     with ignore_warnings(category=(FutureWarning)):
         check_func(estimator.__class__.__name__, estimator)
+
+
+@pytest.mark.parametrize(
+    "estimator", _tested_estimators(), ids=_get_check_estimator_ids
+)
+def test_check_inplace_ensure_writeable(estimator):
+    name = estimator.__class__.__name__
+
+    if hasattr(estimator, "copy"):
+        estimator.set_params(copy=False)
+    elif hasattr(estimator, "copy_X"):
+        estimator.set_params(copy_X=False)
+    else:
+        raise SkipTest(f"{name} doesn't require writeable input.")
+
+    _set_checking_parameters(estimator)
+
+    # The following estimators can work inplace only with certain settings
+    if name == "HDBSCAN":
+        estimator.set_params(metric="precomputed", algorithm="brute")
+
+    if name == "PCA":
+        estimator.set_params(svd_solver="full")
+
+    if name == "KernelPCA":
+        estimator.set_params(kernel="precomputed")
+
+    check_inplace_ensure_writeable(name, estimator)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 5ba1540094588..2108d33d6ad77 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -4722,3 +4722,49 @@ def check_global_set_output_transform_polars(name, transformer_orig):
     _check_set_output_transform_polars_context(name, transformer_orig, "global")
+
+
+@ignore_warnings(category=FutureWarning)
+def check_inplace_ensure_writeable(name, estimator_orig):
+    """Check that estimators able to do inplace operations can work on read-only
+    input data even if a copy is not explicitly requested by the user.
+
+    Make sure that a copy is made and consequently that the input array and its
+    writeability are not modified by the estimator.
+ """ + rng = np.random.RandomState(0) + + estimator = clone(estimator_orig) + set_random_state(estimator) + + n_samples = 100 + + X, _ = make_blobs(n_samples=n_samples, n_features=3, random_state=rng) + X = _enforce_estimator_tags_X(estimator, X) + + # These estimators can only work inplace with fortran ordered input + if name in ("Lasso", "ElasticNet", "MultiTaskElasticNet", "MultiTaskLasso"): + X = np.asfortranarray(X) + + # Add a missing value for imputers so that transform has to do something + if hasattr(estimator, "missing_values"): + X[0, 0] = np.nan + + if is_regressor(estimator): + y = rng.normal(size=n_samples) + else: + y = rng.randint(low=0, high=2, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + + X_copy = X.copy() + + # Make X read-only + X.setflags(write=False) + + estimator.fit(X, y) + + if hasattr(estimator, "transform"): + estimator.transform(X) + + assert not X.flags.writeable + assert_allclose(X, X_copy) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 92fff950e875e..5bde51ae514d9 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -46,6 +46,7 @@ assert_allclose_dense_sparse, assert_array_equal, assert_no_warnings, + create_memmap_backed_data, ignore_warnings, skip_if_array_api_compat_not_configured, ) @@ -2124,3 +2125,67 @@ def __init__(self): self.schema = ["a", "b"] assert not _is_polars_df(LooksLikePolars()) + + +def test_check_array_writeable_np(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on numpy arrays. + """ + X = np.random.uniform(size=(10, 10)) + + out = check_array(X, copy=False, force_writeable=True) + # X is already writeable, no copy is needed + assert np.may_share_memory(out, X) + assert out.flags.writeable + + X.flags.writeable = False + + out = check_array(X, copy=False, force_writeable=True) + # X is not writeable, a copy is made + assert not np.may_share_memory(out, X) + assert out.flags.writeable + + +def test_check_array_writeable_mmap(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on a memory-map. + + A common situation is when a meta-estimators run in parallel using multiprocessing + with joblib, which creates read-only memory-maps of large arrays. + """ + X = np.random.uniform(size=(10, 10)) + + mmap = create_memmap_backed_data(X, mmap_mode="w+") + out = check_array(mmap, copy=False, force_writeable=True) + # mmap is already writeable, no copy is needed + assert np.may_share_memory(out, mmap) + assert out.flags.writeable + + mmap = create_memmap_backed_data(X, mmap_mode="r") + out = check_array(mmap, copy=False, force_writeable=True) + # mmap is read-only, a copy is made + assert not np.may_share_memory(out, mmap) + assert out.flags.writeable + + +def test_check_array_writeable_df(): + """Check the behavior of check_array when a writeable array is requested + without copy if possible, on a dataframe. 
+ """ + pd = pytest.importorskip("pandas") + + X = np.random.uniform(size=(10, 10)) + df = pd.DataFrame(X, copy=False) + + out = check_array(df, copy=False, force_writeable=True) + # df is backed by a writeable array, no copy is needed + assert np.may_share_memory(out, df) + assert out.flags.writeable + + X.flags.writeable = False + df = pd.DataFrame(X, copy=False) + + out = check_array(df, copy=False, force_writeable=True) + # df is backed by a read-only array, a copy is made + assert not np.may_share_memory(out, df) + assert out.flags.writeable diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d632abb77280d..228fbe76a25e1 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -717,6 +717,7 @@ def check_array( dtype="numeric", order=None, copy=False, + force_writeable=False, force_all_finite=True, ensure_2d=True, allow_nd=False, @@ -767,6 +768,13 @@ def check_array( Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion. + force_writeable : bool, default=False + Whether to force the output array to be writeable. If True, the returned array + is guaranteed to be writeable, which may require a copy. Otherwise the + writeability of the input array is preserved. + + .. versionadded:: 1.6 + force_all_finite : bool or 'allow-nan', default=True Whether to raise an error on np.inf, np.nan, pd.NA in array. The possibilities are: @@ -1085,17 +1093,32 @@ def is_sparse(dtype): % (n_features, array.shape, ensure_min_features, context) ) - # With an input pandas dataframe or series, we know we can always make the - # resulting array writeable: - # - if copy=True, we have already made a copy so it is fine to make the - # array writeable - # - if copy=False, the caller is telling us explicitly that we can do - # in-place modifications - # See https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html#read-only-numpy-arrays - # for more details about pandas copy-on-write mechanism, that is enabled by - # default in pandas 3.0.0.dev. - if _is_pandas_df_or_series(array_orig) and hasattr(array, "flags"): - array.flags.writeable = True + if force_writeable: + # By default, array.copy() creates a C-ordered copy. We set order=K to + # preserve the order of the array. + copy_params = {"order": "K"} if not sp.issparse(array) else {} + + array_data = array.data if sp.issparse(array) else array + flags = getattr(array_data, "flags", None) + if not getattr(flags, "writeable", True): + # This situation can only happen when copy=False, the array is read-only and + # a writeable output is requested. This is an ambiguous setting so we chose + # to always (except for one specific setting, see below) make a copy to + # ensure that the output is writeable, even if avoidable, to not overwrite + # the user's data by surprise. + + if _is_pandas_df_or_series(array_orig): + try: + # In pandas >= 3, np.asarray(df), called earlier in check_array, + # returns a read-only intermediate array. It can be made writeable + # safely without copy because if the original DataFrame was backed + # by a read-only array, trying to change the flag would raise an + # error, in which case we make a copy. 
+                    array_data.flags.writeable = True
+                except ValueError:
+                    array = array.copy(**copy_params)
+            else:
+                array = array.copy(**copy_params)

     return array

@@ -1131,6 +1154,7 @@ def check_X_y(
    dtype="numeric",
    order=None,
    copy=False,
+    force_writeable=False,
    force_all_finite=True,
    ensure_2d=True,
    allow_nd=False,
@@ -1185,6 +1209,13 @@ def check_X_y(
         Whether a forced copy will be triggered. If copy=False, a copy might
         be triggered by a conversion.

+    force_writeable : bool, default=False
+        Whether to force the output array to be writeable. If True, the returned array
+        is guaranteed to be writeable, which may require a copy. Otherwise the
+        writeability of the input array is preserved.
+
+        .. versionadded:: 1.6
+
     force_all_finite : bool or 'allow-nan', default=True
         Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter
         does not influence whether y can have np.inf, np.nan, pd.NA values.
@@ -1268,6 +1299,7 @@ def check_X_y(
         dtype=dtype,
         order=order,
         copy=copy,
+        force_writeable=force_writeable,
         force_all_finite=force_all_finite,
         ensure_2d=ensure_2d,
         allow_nd=allow_nd,
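
Usage note (not part of the patch): the semantics of the new `force_writeable` flag are easiest to see end to end. A minimal sketch, assuming a scikit-learn build that includes this patch; it mirrors the added `test_check_array_writeable_np` test:

    import numpy as np
    from sklearn.utils import check_array

    X = np.random.uniform(size=(10, 10))

    # Writeable input: no copy is needed, the same buffer is returned.
    out = check_array(X, copy=False, force_writeable=True)
    assert np.may_share_memory(out, X) and out.flags.writeable

    # Read-only input: a copy is made so that the output is guaranteed to be
    # writeable, while the caller's array keeps its read-only flag.
    X.setflags(write=False)
    out = check_array(X, copy=False, force_writeable=True)
    assert not np.may_share_memory(out, X)
    assert out.flags.writeable and not X.flags.writeable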
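At the estimator level, the regression described in the 1.5.1 changelog entry reduces to the behavior below. Another hedged sketch under the same assumption (patched scikit-learn), patterned on the new `check_inplace_ensure_writeable` check; `Lasso` is only one example of an estimator that validates with `force_writeable=True` and exposes `copy_X`:

    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.RandomState(0)
    # Coordinate descent estimators operate inplace only on Fortran-ordered
    # input, as in the new common check.
    X = np.asfortranarray(rng.normal(size=(100, 3)))
    y = rng.normal(size=100)

    X_copy = X.copy()
    X.setflags(write=False)  # simulate e.g. a read-only joblib memory-map

    # Even with copy_X=False, validation now makes a defensive copy of the
    # read-only input instead of raising an error (or mutating the user's data).
    Lasso(copy_X=False).fit(X, y)

    assert not X.flags.writeable
    np.testing.assert_allclose(X, X_copy)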