Skip to content

Commit

Permalink
ENH Optimize runtime for IsolationForest (scikit-learn#23252)
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas J. Fan <[email protected]>
  • Loading branch information
MaxwellLZH and thomasjpfan authored May 4, 2022
1 parent d3050e4 commit abbeacc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 5 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ Changelog
- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32`
inputs. :pr:`22968` by `Meekail Zain <micky774>`.

:mod:`sklearn.ensemble`
.......................

- |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu <MaxwellLZH>`.

Code and Documentation Contributors
-----------------------------------

Expand Down
9 changes: 6 additions & 3 deletions sklearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def _parallel_build_estimators(
bootstrap_features = ensemble.bootstrap_features
support_sample_weight = has_fit_parameter(ensemble.base_estimator_, "sample_weight")
has_check_input = has_fit_parameter(ensemble.base_estimator_, "check_input")
requires_feature_indexing = bootstrap_features or max_features != n_features

if not support_sample_weight and sample_weight is not None:
raise ValueError("The base estimator doesn't support sample weight")

Expand Down Expand Up @@ -135,10 +137,11 @@ def _parallel_build_estimators(
not_indices_mask = ~indices_to_mask(indices, n_samples)
curr_sample_weight[not_indices_mask] = 0

estimator_fit(X[:, features], y, sample_weight=curr_sample_weight)

X_ = X[:, features] if requires_feature_indexing else X
estimator_fit(X_, y, sample_weight=curr_sample_weight)
else:
estimator_fit(X[indices][:, features], y[indices])
X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
estimator_fit(X_, y[indices])

estimators.append(estimator)
estimators_features.append(features)
Expand Down
6 changes: 5 additions & 1 deletion sklearn/ensemble/_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from warnings import warn

from ..tree import ExtraTreeRegressor
from ..tree._tree import DTYPE as tree_dtype
from ..utils import (
check_random_state,
check_array,
Expand Down Expand Up @@ -80,6 +81,9 @@ class IsolationForest(OutlierMixin, BaseBagging):
- If int, then draw `max_features` features.
- If float, then draw `max_features * X.shape[1]` features.
Note: using a float number less than 1.0 or an integer less than the number
of features will enable feature subsampling and lead to a longer runtime.
bootstrap : bool, default=False
If True, individual trees are fit on random subsets of the training
data sampled with replacement. If False, sampling without replacement
Expand Down Expand Up @@ -254,7 +258,7 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Fitted estimator.
"""
X = self._validate_data(X, accept_sparse=["csc"])
X = self._validate_data(X, accept_sparse=["csc"], dtype=tree_dtype)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
Expand Down
12 changes: 11 additions & 1 deletion sklearn/ensemble/tests/test_iforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sklearn.ensemble import IsolationForest
from sklearn.ensemble._iforest import _average_path_length
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes, load_iris
from sklearn.datasets import load_diabetes, load_iris, make_classification
from sklearn.utils import check_random_state
from sklearn.metrics import roc_auc_score

Expand Down Expand Up @@ -347,3 +347,13 @@ def test_n_features_deprecation():

with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
est.n_features_


def test_iforest_with_n_jobs_does_not_segfault():
    """Fit a parallel IsolationForest on sparse data without crashing.

    Non-regression test for #23252. There is nothing to assert on: the
    test passes as long as ``fit`` returns with ``n_jobs=2``.
    """
    # Build the 85,000 x 100 dataset and convert it straight to CSC so the
    # dense array is not kept alive under a separate name.
    features = csc_matrix(
        make_classification(n_samples=85_000, n_features=100, random_state=0)[0]
    )
    forest = IsolationForest(n_estimators=10, max_samples=256, n_jobs=2)
    forest.fit(features)

0 comments on commit abbeacc

Please sign in to comment.