diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 3a21304a39a3e..310df6b12a6ec 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -108,6 +108,7 @@ Metrics - :func:`sklearn.metrics.accuracy_score` - :func:`sklearn.metrics.mean_absolute_error` - :func:`sklearn.metrics.mean_tweedie_deviance` +- :func:`sklearn.metrics.pairwise.cosine_similarity` - :func:`sklearn.metrics.r2_score` - :func:`sklearn.metrics.zero_one_loss` diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index aff2ea2b011da..601868a9a9581 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -34,8 +34,10 @@ See :ref:`array_api` for more details. - :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible inputs. - :pr:`28106` by :user:`Thomas Li ` -- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `. + :pr:`28106` by :user:`Thomas Li `; +- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `; +- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati `. + **Classes:** diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index d30c1775823a5..ff158825cc0f9 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -25,6 +25,11 @@ gen_batches, gen_even_slices, ) +from ..utils._array_api import ( + _find_matching_floating_dtype, + _is_numpy_namespace, + get_namespace, +) from ..utils._chunking import get_chunk_n_rows from ..utils._mask import _get_mask from ..utils._missing import is_scalar_nan @@ -154,7 +159,11 @@ def check_pairwise_arrays( An array equal to Y if Y was not None, guaranteed to be a numpy array. If Y was None, safe_Y will be a pointer to X. 
""" - X, Y, dtype_float = _return_float_dtype(X, Y) + xp, _ = get_namespace(X, Y) + if any([issparse(X), issparse(Y)]) or _is_numpy_namespace(xp): + X, Y, dtype_float = _return_float_dtype(X, Y) + else: + dtype_float = _find_matching_floating_dtype(X, Y, xp=xp) estimator = "check_pairwise_arrays" if dtype == "infer_float": diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index ae47ffe3d6a56..9e94b9241de7a 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -51,6 +51,7 @@ zero_one_loss, ) from sklearn.metrics._base import _average_binary_score +from sklearn.metrics.pairwise import cosine_similarity from sklearn.preprocessing import LabelBinarizer from sklearn.utils import shuffle from sklearn.utils._array_api import ( @@ -1743,20 +1744,22 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str): def check_array_api_metric( - metric, array_namespace, device, dtype_name, y_true_np, y_pred_np, sample_weight + metric, array_namespace, device, dtype_name, a_np, b_np, **metric_kwargs ): xp = _array_api_for_tests(array_namespace, device) - y_true_xp = xp.asarray(y_true_np, device=device) - y_pred_xp = xp.asarray(y_pred_np, device=device) + a_xp = xp.asarray(a_np, device=device) + b_xp = xp.asarray(b_np, device=device) - metric_np = metric(y_true_np, y_pred_np, sample_weight=sample_weight) + metric_np = metric(a_np, b_np, **metric_kwargs) - if sample_weight is not None: - sample_weight = xp.asarray(sample_weight, device=device) + if metric_kwargs.get("sample_weight") is not None: + metric_kwargs["sample_weight"] = xp.asarray( + metric_kwargs["sample_weight"], device=device + ) with config_context(array_api_dispatch=True): - metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight) + metric_xp = metric(a_xp, b_xp, **metric_kwargs) assert_allclose( _convert_to_numpy(xp.asarray(metric_xp), xp), @@ -1776,8 +1779,8 @@ def 
check_array_api_binary_classification_metric( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=None, ) @@ -1788,8 +1791,8 @@ def check_array_api_binary_classification_metric( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=sample_weight, ) @@ -1805,8 +1808,8 @@ def check_array_api_multiclass_classification_metric( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=None, ) @@ -1817,8 +1820,8 @@ def check_array_api_multiclass_classification_metric( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=sample_weight, ) @@ -1832,8 +1835,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=None, ) @@ -1844,8 +1847,8 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=sample_weight, ) @@ -1861,8 +1864,8 @@ def check_array_api_regression_metric_multioutput( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=None, ) @@ -1873,8 +1876,8 @@ def check_array_api_regression_metric_multioutput( array_namespace, device, dtype_name, - y_true_np=y_true_np, - y_pred_np=y_pred_np, + a_np=y_true_np, + b_np=y_pred_np, sample_weight=sample_weight, ) @@ -1886,6 +1889,26 @@ def check_array_api_multioutput_regression_metric( check_array_api_regression_metric(metric, array_namespace, device, dtype_name) +def check_array_api_metric_pairwise(metric, 
array_namespace, device, dtype_name): + """Compare a pairwise metric on NumPy inputs and their array API copies.""" + X_np = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=dtype_name) + Y_np = np.array([[0.2, 0.3, 0.4], [0.5, 0.6, 0.7]], dtype=dtype_name) + + metric_kwargs = {} + if "dense_output" in signature(metric).parameters: + metric_kwargs["dense_output"] = True + + check_array_api_metric( + metric, + array_namespace, + device, + dtype_name, + a_np=X_np, + b_np=Y_np, + **metric_kwargs, + ) + + array_api_metric_checkers = { accuracy_score: [ check_array_api_binary_classification_metric, @@ -1900,6 +1923,7 @@ def check_array_api_multioutput_regression_metric( check_array_api_regression_metric, check_array_api_regression_metric_multioutput, ], + cosine_similarity: [check_array_api_metric_pairwise], mean_absolute_error: [ check_array_api_regression_metric, check_array_api_multioutput_regression_metric,