Skip to content

Commit

Permalink
Fix use metric_kwargs in check_array_api_metric_pairwise (scikit-lear…
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati committed May 20, 2024
1 parent 18c1972 commit 3a023b0
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,12 +1898,24 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
if "dense_output" in signature(metric).parameters:
metric_kwargs["dense_output"] = False
check_array_api_metric(
metric, array_namespace, device, dtype_name, a_np=X_np, b_np=Y_np
metric,
array_namespace,
device,
dtype_name,
a_np=X_np,
b_np=Y_np,
**metric_kwargs,
)
metric_kwargs["dense_output"] = True

check_array_api_metric(
metric, array_namespace, device, dtype_name, a_np=X_np, b_np=Y_np
metric,
array_namespace,
device,
dtype_name,
a_np=X_np,
b_np=Y_np,
**metric_kwargs,
)


Expand Down

0 comments on commit 3a023b0

Please sign in to comment.