Skip to content

Commit

Permalink
TST Make test_neighbors_metric robust to rng (scikit-learn#28888)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Apr 26, 2024
1 parent 62818c3 commit bd8ca47
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions sklearn/neighbors/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,8 +1644,16 @@ def test_nearest_neighbors_validate_params():
+ DISTANCE_METRIC_OBJS,
)
def test_neighbors_metrics(
global_dtype, metric, n_samples=20, n_features=3, n_query_pts=2, n_neighbors=5
global_dtype,
global_random_seed,
metric,
n_samples=20,
n_features=3,
n_query_pts=2,
n_neighbors=5,
):
rng = np.random.RandomState(global_random_seed)

metric = _parse_metric(metric, global_dtype)

# Test computing the neighbors for various metrics
Expand Down Expand Up @@ -1697,15 +1705,19 @@ def test_neighbors_metrics(
brute_dst, brute_idx = results["brute"]
ball_tree_dst, ball_tree_idx = results["ball_tree"]

assert_allclose(brute_dst, ball_tree_dst)
# The returned distances are always in float64 regardless of the input dtype
# We need to adjust the tolerance w.r.t the input dtype
rtol = 1e-7 if global_dtype == np.float64 else 1e-4

assert_allclose(brute_dst, ball_tree_dst, rtol=rtol)
assert_array_equal(brute_idx, ball_tree_idx)

if not exclude_kd_tree:
kd_tree_dst, kd_tree_idx = results["kd_tree"]
assert_allclose(brute_dst, kd_tree_dst)
assert_allclose(brute_dst, kd_tree_dst, rtol=rtol)
assert_array_equal(brute_idx, kd_tree_idx)

assert_allclose(ball_tree_dst, kd_tree_dst)
assert_allclose(ball_tree_dst, kd_tree_dst, rtol=rtol)
assert_array_equal(ball_tree_idx, kd_tree_idx)


Expand Down

0 comments on commit bd8ca47

Please sign in to comment.