FIX Lazy instantiate the ThreadpoolController (scikit-learn#29235)
jeremiedbb committed Jun 13, 2024
1 parent 55ca335 commit 6595229
Showing 11 changed files with 70 additions and 38 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -23,6 +23,12 @@ Version 1.5.1
Changelog
---------

Changes impacting many modules
------------------------------

- |Fix| Fixed a regression causing a deadlock at import time in some settings.
:pr:`29235` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.metrics`
......................

15 changes: 0 additions & 15 deletions sklearn/__init__.py
@@ -70,15 +70,6 @@
# We are not importing the rest of scikit-learn during the build
# process, as it may not be compiled yet
else:
# Import numpy, scipy to make sure that the BLAS libs are loaded before
# creating the ThreadpoolController. They would be imported just after
# when importing utils anyway. This makes it explicit and robust to changes
# in utils.
# (OpenMP is loaded by importing show_versions right after this block)
import numpy # noqa
import scipy.linalg # noqa
from threadpoolctl import ThreadpoolController

# `_distributor_init` allows distributors to run custom init code.
# For instance, for the Windows wheel, this is used to pre-load the
# vcomp shared library runtime for OpenMP embedded in the sklearn/.libs
@@ -147,12 +138,6 @@
except ModuleNotFoundError:
pass

# Set a global controller that can be used to locally limit the number of
# threads without looping through all shared libraries every time.
# This instantiation should not happen earlier because it needs all BLAS and
# OpenMP libs to be loaded first.
_threadpool_controller = ThreadpoolController()


def setup_module(module):
"""Fixture for the tests to assure globally controllable seeding of RNGs"""
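
For context: `ThreadpoolController` inspects the shared libraries already loaded in the process at construction time, which is why the removed block imported numpy and scipy before creating it. A minimal sketch of that behavior, assuming threadpoolctl 3.x (not part of the diff):

import numpy  # noqa: F401  # loads the BLAS runtime into the process

from threadpoolctl import ThreadpoolController

# The controller only sees libraries loaded *before* this line runs.
controller = ThreadpoolController()
for module in controller.info():
    print(module["user_api"], module["internal_api"], module["num_threads"])
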
15 changes: 9 additions & 6 deletions sklearn/cluster/_kmeans.py
@@ -18,7 +18,6 @@
import numpy as np
import scipy.sparse as sp

from .. import _threadpool_controller
from ..base import (
BaseEstimator,
ClassNamePrefixFeaturesOutMixin,
@@ -32,6 +31,10 @@
from ..utils._openmp_helpers import _openmp_effective_n_threads
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.extmath import row_norms, stable_cumsum
from ..utils.parallel import (
_get_threadpool_controller,
_threadpool_controller_decorator,
)
from ..utils.sparsefuncs import mean_variance_axis
from ..utils.sparsefuncs_fast import assign_rows_csr
from ..utils.validation import (
@@ -624,7 +627,7 @@ def _kmeans_single_elkan(

# Threadpoolctl decorator to limit the number of threads in second level of
# nested parallelism (i.e. BLAS) to avoid oversubscription.
@_threadpool_controller.wrap(limits=1, user_api="blas")
@_threadpool_controller_decorator(limits=1, user_api="blas")
def _kmeans_single_lloyd(
X,
sample_weight,
@@ -827,7 +830,7 @@ def _labels_inertia(X, sample_weight, centers, n_threads=1, return_inertia=True)


# Same as _labels_inertia but in a threadpool_limits context.
_labels_inertia_threadpool_limit = _threadpool_controller.wrap(
_labels_inertia_threadpool_limit = _threadpool_controller_decorator(
limits=1, user_api="blas"
)(_labels_inertia)

@@ -922,7 +925,7 @@ def _check_mkl_vcomp(self, X, n_samples):

n_active_threads = int(np.ceil(n_samples / CHUNK_SIZE))
if n_active_threads < self._n_threads:
modules = _threadpool_controller.info()
modules = _get_threadpool_controller().info()
has_vcomp = "vcomp" in [module["prefix"] for module in modules]
has_mkl = ("mkl", "intel") in [
(module["internal_api"], module.get("threading_layer", None))
@@ -2144,7 +2147,7 @@ def fit(self, X, y=None, sample_weight=None):

n_steps = (self.max_iter * n_samples) // self._batch_size

with _threadpool_controller.limit(limits=1, user_api="blas"):
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
# Perform the iterative optimization until convergence
for i in range(n_steps):
# Sample a minibatch from the full dataset
@@ -2270,7 +2273,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
# Initialize number of samples seen since last reassignment
self._n_since_last_reassign = 0

with _threadpool_controller.limit(limits=1, user_api="blas"):
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
_mini_batch_step(
X,
sample_weight=sample_weight,
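
The call-site changes are mechanical: the module-level `_threadpool_controller` becomes a lazy lookup. A condensed sketch of the two private helpers as `_kmeans.py` now uses them (`estimate` is a hypothetical function, not from the diff):

from sklearn.utils.parallel import (
    _get_threadpool_controller,
    _threadpool_controller_decorator,
)

# Decorator form: no ThreadpoolController is created when the defining
# module is imported; one is built lazily on the first call to `estimate`.
@_threadpool_controller_decorator(limits=1, user_api="blas")
def estimate(X):
    ...

# Context-manager form, for limiting a specific region only:
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
    pass  # e.g. the minibatch loop in MiniBatchKMeans.fit
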
6 changes: 3 additions & 3 deletions sklearn/cluster/tests/test_k_means.py
@@ -8,7 +8,6 @@
import pytest
from scipy import sparse as sp

from sklearn import _threadpool_controller
from sklearn.base import clone
from sklearn.cluster import KMeans, MiniBatchKMeans, k_means, kmeans_plusplus
from sklearn.cluster._k_means_common import (
@@ -33,6 +32,7 @@
)
from sklearn.utils.extmath import row_norms
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.parallel import _get_threadpool_controller

# non centered, sparse centers to check the
centers = np.array(
@@ -968,13 +968,13 @@ def test_result_equal_in_diff_n_threads(Estimator, global_random_seed):
rnd = np.random.RandomState(global_random_seed)
X = rnd.normal(size=(50, 10))

with _threadpool_controller.limit(limits=1, user_api="openmp"):
with _get_threadpool_controller().limit(limits=1, user_api="openmp"):
result_1 = (
Estimator(n_clusters=n_clusters, random_state=global_random_seed)
.fit(X)
.labels_
)
with _threadpool_controller.limit(limits=2, user_api="openmp"):
with _get_threadpool_controller().limit(limits=2, user_api="openmp"):
result_2 = (
Estimator(n_clusters=n_clusters, random_state=global_random_seed)
.fit(X)
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp
@@ -14,7 +14,7 @@ from numbers import Integral
from scipy.sparse import issparse
from ...utils import check_array, check_scalar
from ...utils.fixes import _in_unstable_openblas_configuration
from ... import _threadpool_controller
from ...utils.parallel import _get_threadpool_controller

{{for name_suffix in ['64', '32']}}

@@ -58,7 +58,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
"""
# Limit the number of threads in second level of nested parallelism for BLAS
# to avoid threads over-subscription (in DOT or GEMM for instance).
with _threadpool_controller.limit(limits=1, user_api='blas'):
with _get_threadpool_controller().limit(limits=1, user_api='blas'):
if metric in ("euclidean", "sqeuclidean"):
# Specialized implementation of ArgKmin for the Euclidean distance
# for the dense-dense and sparse-sparse cases.
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp
@@ -4,7 +4,7 @@ from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

from ...utils._typedefs cimport intp_t, float64_t
from ... import _threadpool_controller
from ...utils.parallel import _get_threadpool_controller

import numpy as np
from scipy.sparse import issparse
@@ -66,7 +66,7 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):

# Limit the number of threads in second level of nested parallelism for BLAS
# to avoid threads over-subscription (in GEMM for instance).
with _threadpool_controller.limit(limits=1, user_api="blas"):
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
if pda.execute_in_parallel_on_Y:
pda._parallel_on_Y()
else:
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp
@@ -17,7 +17,7 @@ from numbers import Real
from scipy.sparse import issparse
from ...utils import check_array, check_scalar
from ...utils.fixes import _in_unstable_openblas_configuration
from ... import _threadpool_controller
from ...utils.parallel import _get_threadpool_controller

cnp.import_array()

@@ -110,7 +110,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}})

# Limit the number of threads in second level of nested parallelism for BLAS
# to avoid threads over-subscription (in GEMM for instance).
with _threadpool_controller.limit(limits=1, user_api="blas"):
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
if pda.execute_in_parallel_on_Y:
pda._parallel_on_Y()
else:
4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp
@@ -8,7 +8,7 @@ from ...utils._typedefs cimport intp_t, float64_t

import numpy as np
from scipy.sparse import issparse
from ... import _threadpool_controller
from ...utils.parallel import _get_threadpool_controller


{{for name_suffix in ["32", "64"]}}
@@ -60,7 +60,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix}

# Limit the number of threads in second level of nested parallelism for BLAS
# to avoid threads over-subscription (in GEMM for instance).
with _threadpool_controller.limit(limits=1, user_api="blas"):
with _get_threadpool_controller().limit(limits=1, user_api="blas"):
if pda.execute_in_parallel_on_Y:
pda._parallel_on_Y()
else:
4 changes: 2 additions & 2 deletions sklearn/metrics/tests/test_pairwise_distances_reduction.py
@@ -7,7 +7,6 @@
import pytest
from scipy.spatial.distance import cdist

from sklearn import _threadpool_controller
from sklearn.metrics import euclidean_distances, pairwise_distances
from sklearn.metrics._pairwise_distances_reduction import (
ArgKmin,
@@ -23,6 +22,7 @@
create_memmap_backed_data,
)
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.parallel import _get_threadpool_controller

# Common supported metric between scipy.spatial.distance.cdist
# and BaseDistanceReductionDispatcher.
@@ -1200,7 +1200,7 @@ def test_n_threads_agnosticism(
**compute_parameters,
)

with _threadpool_controller.limit(limits=1, user_api="openmp"):
with _get_threadpool_controller().limit(limits=1, user_api="openmp"):
dist, indices = Dispatcher.compute(
X,
Y,
5 changes: 2 additions & 3 deletions sklearn/utils/fixes.py
@@ -19,9 +19,8 @@
import scipy.sparse.linalg
import scipy.stats

import sklearn

from ..externals._packaging.version import parse as parse_version
from .parallel import _get_threadpool_controller

_IS_32BIT = 8 * struct.calcsize("P") == 32
_IS_WASM = platform.machine() in ["wasm32", "wasm64"]
@@ -390,7 +389,7 @@ def _in_unstable_openblas_configuration():
import numpy # noqa
import scipy # noqa

modules_info = sklearn._threadpool_controller.info()
modules_info = _get_threadpool_controller().info()

open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info)
if not open_blas_used:
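
The OpenBLAS check consumes only the controller's metadata. For reference, `ThreadpoolController.info()` returns one dict per detected threadpool library; a hedged sketch of what the check reads (field values illustrative):

from sklearn.utils.parallel import _get_threadpool_controller

modules_info = _get_threadpool_controller().info()
# Each entry resembles:
#   {"user_api": "blas", "internal_api": "openblas",
#    "prefix": "libopenblas", "num_threads": 8, "version": "0.3.x", ...}
open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info)
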
41 changes: 40 additions & 1 deletion sklearn/utils/parallel.py
@@ -1,13 +1,22 @@
"""Customizations of :mod:`joblib` tools for scikit-learn usage."""
"""Customizations of :mod:`joblib` and :mod:`threadpoolctl` tools for scikit-learn
usage.
"""

import functools
import warnings
from functools import update_wrapper

import joblib
from threadpoolctl import ThreadpoolController

from .._config import config_context, get_config

# Global threadpool controller instance that can be used to locally limit the number of
# threads without looping through all shared libraries every time.
# It should not be accessed directly and _get_threadpool_controller should be used
# instead.
_threadpool_controller = None


def _with_config(delayed_func, config):
"""Helper function that intends to attach a config to a delayed function."""
@@ -125,3 +134,33 @@ def __call__(self, *args, **kwargs):
config = {}
with config_context(**config):
return self.function(*args, **kwargs)


def _get_threadpool_controller():
    """Return the global threadpool controller instance."""
    global _threadpool_controller

    if _threadpool_controller is None:
        _threadpool_controller = ThreadpoolController()

    return _threadpool_controller


def _threadpool_controller_decorator(limits=1, user_api="blas"):
    """Decorator to limit the number of threads used at the function level.

    It should be preferred over `threadpoolctl.ThreadpoolController.wrap` because
    this one only loads the shared libraries when the decorated function is called,
    while the latter loads them at import time.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            controller = _get_threadpool_controller()
            with controller.limit(limits=limits, user_api=user_api):
                return func(*args, **kwargs)

        return wrapper

    return decorator
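
End to end, the new factory behaves like `ThreadpoolController.wrap` at call time but defers the expensive library scan until the wrapped function first runs. A small usage sketch (`solve` is hypothetical):

import numpy as np

from sklearn.utils.parallel import _threadpool_controller_decorator

@_threadpool_controller_decorator(limits=1, user_api="blas")
def solve(a, b):
    # BLAS/LAPACK calls in here run on a single thread.
    return np.linalg.solve(a, b)

# Importing the module that defines `solve` no longer instantiates a
# ThreadpoolController; one is created on the first `solve(...)` call.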
