diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index dadae86689e08..3a21304a39a3e 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -106,6 +106,7 @@ Metrics
 -------
 
 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.mean_absolute_error`
 - :func:`sklearn.metrics.mean_tweedie_deviance`
 - :func:`sklearn.metrics.r2_score`
 - :func:`sklearn.metrics.zero_one_loss`
diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst
index 5000866b59c03..b4d26e07dffc0 100644
--- a/doc/whats_new/v1.6.rst
+++ b/doc/whats_new/v1.6.rst
@@ -35,6 +35,7 @@ See :ref:`array_api` for more details.
 - :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
   inputs.
   :pr:`28106` by :user:`Thomas Li `
+- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati `.
 
 **Classes:**
 
diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py
index 596a45dd3eaed..61bb1caa2d9da 100644
--- a/sklearn/metrics/_regression.py
+++ b/sklearn/metrics/_regression.py
@@ -189,7 +189,7 @@ def mean_absolute_error(
 
     Returns
     -------
-    loss : float or ndarray of floats
+    loss : float or array of floats
         If multioutput is 'raw_values', then mean absolute error is returned
         for each output separately.
         If multioutput is 'uniform_average' or an ndarray of weights, then the
@@ -213,11 +213,19 @@ def mean_absolute_error(
     >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
     0.85...
     """
-    y_type, y_true, y_pred, multioutput = _check_reg_targets(
-        y_true, y_pred, multioutput
+    input_arrays = [y_true, y_pred, sample_weight, multioutput]
+    xp, _ = get_namespace(*input_arrays)
+
+    dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
+
+    _, y_true, y_pred, multioutput = _check_reg_targets(
+        y_true, y_pred, multioutput, dtype=dtype, xp=xp
     )
     check_consistent_length(y_true, y_pred, sample_weight)
-    output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)
+
+    output_errors = _average(
+        xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
+    )
     if isinstance(multioutput, str):
         if multioutput == "raw_values":
             return output_errors
@@ -225,7 +233,15 @@ def mean_absolute_error(
             # pass None as weights to np.average: uniform mean
             multioutput = None
 
-    return np.average(output_errors, weights=multioutput)
+    # Average across the outputs (if needed).
+    mean_absolute_error = _average(output_errors, weights=multioutput)
+
+    # Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
+    # should always return a scalar array that we convert to a Python float to
+    # consistently return the same eager evaluated value, irrespective of the
+    # Array API implementation.
+    assert mean_absolute_error.shape == ()
+    return float(mean_absolute_error)
 
 
 @validate_params(
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index f00af5e160858..ae47ffe3d6a56 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1879,6 +1879,13 @@ def check_array_api_regression_metric_multioutput(
     )
 
 
+def check_array_api_multioutput_regression_metric(
+    metric, array_namespace, device, dtype_name
+):
+    metric = partial(metric, multioutput="raw_values")
+    check_array_api_regression_metric(metric, array_namespace, device, dtype_name)
+
+
 array_api_metric_checkers = {
     accuracy_score: [
         check_array_api_binary_classification_metric,
@@ -1893,6 +1900,10 @@ def check_array_api_regression_metric_multioutput(
         check_array_api_regression_metric,
         check_array_api_regression_metric_multioutput,
     ],
+    mean_absolute_error: [
+        check_array_api_regression_metric,
+        check_array_api_multioutput_regression_metric,
+    ],
 }
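For context, here is a minimal usage sketch (not part of the patch) of how the new code path can be exercised end to end. It assumes PyTorch and the `array-api-compat` package are installed; any other Array API namespace (e.g. `array_api_strict` or CuPy) would go through the same `sklearn.config_context(array_api_dispatch=True)` switch.

```python
# Hypothetical usage sketch, not taken from the diff above.
import torch

import sklearn
from sklearn.metrics import mean_absolute_error

y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])

with sklearn.config_context(array_api_dispatch=True):
    # The averaging now happens through `_average` in the torch namespace,
    # and the aggregated result is converted to a plain Python float.
    loss = mean_absolute_error(y_true, y_pred)
    print(loss)  # 0.5

    # With multioutput="raw_values" the per-output errors are returned in the
    # input namespace (here a torch tensor), matching the updated docstring
    # ("float or array of floats").
    per_output = mean_absolute_error(
        torch.tensor([[0.5, 1.0], [-1.0, 1.0]]),
        torch.tensor([[0.0, 2.0], [-1.0, 2.0]]),
        multioutput="raw_values",
    )
    print(per_output)  # tensor([0.2500, 1.0000])
```

Returning a Python float for the aggregated case keeps the result eagerly evaluated and framework independent, which is what the `assert mean_absolute_error.shape == ()` check at the end of the patched function guards.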