Skip to content

Commit

Permalink
FIX validate parameter in 'fit' for 'FactorAnalysis' (scikit-learn#21713
Browse files Browse the repository at this point in the history
)
  • Loading branch information
HayaAlmutairi committed Nov 26, 2021
1 parent b8c9aff commit a8bd6d2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ Changelog

- |Fix| :class:`decomposition.FastICA` now validates input parameters in `fit` instead of `__init__`.
:pr:`21432` by :user:`Hannah Bohle <hhnnhh>` and :user:`Maren Westermann <marenwestermann>`.

- |Fix| :class:`decomposition.FactorAnalysis` now validates input parameters
in `fit` instead of `__init__`.
:pr:`21713` by :user:`Haya <HayaAlmutairi>` and
:user:`Krum Arnaudov <krumeto>`.

- |Fix| :class:`decomposition.KernelPCA` now validates input parameters in
`fit` instead of `__init__`.
Expand Down
12 changes: 7 additions & 5 deletions sklearn/decomposition/_factor_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,6 @@ def __init__(
self.copy = copy
self.tol = tol
self.max_iter = max_iter
if svd_method not in ["lapack", "randomized"]:
raise ValueError(
"SVD method %s is not supported. Please consider the documentation"
% svd_method
)
self.svd_method = svd_method

self.noise_variance_init = noise_variance_init
Expand All @@ -204,6 +199,13 @@ def fit(self, X, y=None):
self : object
FactorAnalysis class instance.
"""

if self.svd_method not in ["lapack", "randomized"]:
raise ValueError(
f"SVD method {self.svd_method!r} is not supported. Possible methods "
"are either 'lapack' or 'randomized'."
)

X = self._validate_data(X, copy=self.copy, dtype=np.float64)

n_samples, n_features = X.shape
Expand Down
8 changes: 3 additions & 5 deletions sklearn/decomposition/tests/test_factor_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ def test_factor_analysis():
# wlog, mean is 0
X = np.dot(h, W) + noise

with pytest.raises(ValueError):
FactorAnalysis(svd_method="foo")
fa_fail = FactorAnalysis()
fa_fail.svd_method = "foo"
with pytest.raises(ValueError):
fa_fail = FactorAnalysis(svd_method="foo")
msg = "SVD method 'foo' is not supported"
with pytest.raises(ValueError, match=msg):
fa_fail.fit(X)
fas = []
for method in ["randomized", "lapack"]:
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ def test_transformers_get_feature_names_out(transformer):

VALIDATE_ESTIMATOR_INIT = [
"ColumnTransformer",
"FactorAnalysis",
"FeatureHasher",
"FeatureUnion",
"GridSearchCV",
Expand Down

0 comments on commit a8bd6d2

Please sign in to comment.