Skip to content

Commit

Permalink
MAINT Clean up deprecations for 1.5: in AdditiveChi2Sampler (scikit-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Apr 26, 2024
1 parent f4cc029 commit c35a719
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 71 deletions.
81 changes: 23 additions & 58 deletions sklearn/kernel_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_fit_context,
)
from .metrics.pairwise import KERNEL_PARAMS, PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
from .utils import check_random_state, deprecated
from .utils import check_random_state
from .utils._param_validation import Interval, StrOptions
from .utils.extmath import safe_sparse_dot
from .utils.validation import (
Expand Down Expand Up @@ -600,13 +600,6 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
Attributes
----------
sample_interval_ : float
Stored sampling interval. Specified as a parameter if `sample_steps`
not in {1,2,3}.
.. deprecated:: 1.3
`sample_interval_` serves internal purposes only and will be removed in 1.5.
n_features_in_ : int
Number of features seen during :term:`fit`.
Expand Down Expand Up @@ -693,37 +686,14 @@ def fit(self, X, y=None):
X = self._validate_data(X, accept_sparse="csr")
check_non_negative(X, "X in AdditiveChi2Sampler.fit")

# TODO(1.5): remove the setting of _sample_interval from fit
if self.sample_interval is None:
# See figure 2 c) of "Efficient additive kernels via explicit feature maps"
# <http:https://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence,
# 2011
if self.sample_steps == 1:
self._sample_interval = 0.8
elif self.sample_steps == 2:
self._sample_interval = 0.5
elif self.sample_steps == 3:
self._sample_interval = 0.4
else:
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)
else:
self._sample_interval = self.sample_interval
if self.sample_interval is None and self.sample_steps not in (1, 2, 3):
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)

return self

# TODO(1.5): remove
@deprecated( # type: ignore
"The ``sample_interval_`` attribute was deprecated in version 1.3 and "
"will be removed 1.5."
)
@property
def sample_interval_(self):
return self._sample_interval

def transform(self, X):
"""Apply approximate feature map to X.
Expand All @@ -744,29 +714,24 @@ def transform(self, X):
check_non_negative(X, "X in AdditiveChi2Sampler.transform")
sparse = sp.issparse(X)

if hasattr(self, "_sample_interval"):
# TODO(1.5): remove this branch
sample_interval = self._sample_interval

else:
if self.sample_interval is None:
# See figure 2 c) of "Efficient additive kernels via explicit feature maps" # noqa
# <http:https://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence, # noqa
# 2011
if self.sample_steps == 1:
sample_interval = 0.8
elif self.sample_steps == 2:
sample_interval = 0.5
elif self.sample_steps == 3:
sample_interval = 0.4
else:
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)
if self.sample_interval is None:
# See figure 2 c) of "Efficient additive kernels via explicit feature maps" # noqa
# <http:https://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence, # noqa
# 2011
if self.sample_steps == 1:
sample_interval = 0.8
elif self.sample_steps == 2:
sample_interval = 0.5
elif self.sample_steps == 3:
sample_interval = 0.4
else:
sample_interval = self.sample_interval
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)
else:
sample_interval = self.sample_interval

# zeroth component
# 1/cosh = sech
Expand Down
13 changes: 0 additions & 13 deletions sklearn/tests/test_kernel_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,19 +144,6 @@ def test_additive_chi2_sampler_sample_steps(method, sample_steps):
assert transformer.sample_interval == sample_interval


# TODO(1.5): remove
def test_additive_chi2_sampler_future_warnings():
"""Check that we raise a FutureWarning when accessing to `sample_interval_`."""
transformer = AdditiveChi2Sampler()
transformer.fit(X)
msg = re.escape(
"The ``sample_interval_`` attribute was deprecated in version 1.3 and "
"will be removed 1.5."
)
with pytest.warns(FutureWarning, match=msg):
assert transformer.sample_interval_ is not None


@pytest.mark.parametrize("method", ["fit", "fit_transform", "transform"])
def test_additive_chi2_sampler_wrong_sample_steps(method):
"""Check that we raise a ValueError on invalid sample_steps"""
Expand Down

0 comments on commit c35a719

Please sign in to comment.