forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TST introducing the random_seed fixture (scikit-learn#22749)
Co-authored-by: Julien Jerphanion <[email protected]> Co-authored-by: Thomas J. Fan <[email protected]> Co-authored-by: Jérémie du Boisberranger <[email protected]>
- Loading branch information
1 parent
6904ae3
commit d3429ca
Showing
7 changed files
with
162 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""global_random_seed fixture | ||
The goal of this fixture is to prevent tests that use it to be sensitive | ||
to a specific seed value while still being deterministic by default. | ||
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED | ||
variable for insrtuctions on how to use this fixture. | ||
https://scikit-learn.org/dev/computing/parallelism.html#environment-variables | ||
""" | ||
import pytest | ||
from os import environ | ||
from random import Random | ||
|
||
|
||
# Passes the main worker's random seeds to workers | ||
class XDistHooks: | ||
def pytest_configure_node(self, node) -> None: | ||
random_seeds = node.config.getoption("random_seeds") | ||
node.workerinput["random_seeds"] = random_seeds | ||
|
||
|
||
def pytest_configure(config): | ||
if config.pluginmanager.hasplugin("xdist"): | ||
config.pluginmanager.register(XDistHooks()) | ||
|
||
RANDOM_SEED_RANGE = list(range(100)) # All seeds in [0, 99] should be valid. | ||
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED") | ||
if hasattr(config, "workinput"): | ||
# Set worker random seed from seed generated from main process | ||
random_seeds = config.workerinput["random_seeds"] | ||
elif random_seed_var is None: | ||
# This is the way. | ||
random_seeds = [42] | ||
elif random_seed_var == "any": | ||
# Pick-up one seed at random in the range of admissible random seeds. | ||
random_seeds = [Random().choice(RANDOM_SEED_RANGE)] | ||
elif random_seed_var == "all": | ||
random_seeds = RANDOM_SEED_RANGE | ||
else: | ||
if "-" in random_seed_var: | ||
start, stop = random_seed_var.split("-") | ||
random_seeds = list(range(int(start), int(stop) + 1)) | ||
else: | ||
random_seeds = [int(random_seed_var)] | ||
|
||
if min(random_seeds) < 0 or max(random_seeds) > 99: | ||
raise ValueError( | ||
"The value(s) of the environment variable " | ||
"SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] " | ||
f"(or 'any' or 'all'), got: {random_seed_var}" | ||
) | ||
config.option.random_seeds = random_seeds | ||
|
||
class GlobalRandomSeedPlugin: | ||
@pytest.fixture(params=random_seeds) | ||
def global_random_seed(self, request): | ||
"""Fixture to ask for a random yet controllable random seed. | ||
All tests that use this fixture accept the contract that they should | ||
deterministically pass for any seed value from 0 to 99 included. | ||
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED | ||
variable for insrtuctions on how to use this fixture. | ||
https://scikit-learn.org/dev/computing/parallelism.html#environment-variables | ||
""" | ||
yield request.param | ||
|
||
config.pluginmanager.register(GlobalRandomSeedPlugin()) | ||
|
||
|
||
def pytest_report_header(config): | ||
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED") | ||
if random_seed_var == "any": | ||
return [ | ||
"To reproduce this test run, set the following environment variable:", | ||
f' SKLEARN_TESTS_GLOBAL_RANDOM_SEED="{config.option.random_seeds[0]}"', | ||
"See: https://scikit-learn.org/dev/computing/parallelism.html" | ||
"#environment-variables", | ||
] |