Skip to content

Commit

Permalink
ENH Add Array API compatibility to mean_absolute_error (scikit-lear…
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati committed May 15, 2024
1 parent 28c9f50 commit 9f44f1f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Metrics
-------

- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_tweedie_deviance`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
inputs.
:pr:`28106` by :user:`Thomas Li <lithomas1>`
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`.

**Classes:**

Expand Down
26 changes: 21 additions & 5 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def mean_absolute_error(
Returns
-------
loss : float or ndarray of floats
loss : float or array of floats
If multioutput is 'raw_values', then mean absolute error is returned
for each output separately.
If multioutput is 'uniform_average' or an ndarray of weights, then the
Expand All @@ -213,19 +213,35 @@ def mean_absolute_error(
>>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
0.85...
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
input_arrays = [y_true, y_pred, sample_weight, multioutput]
xp, _ = get_namespace(*input_arrays)

dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
)
check_consistent_length(y_true, y_pred, sample_weight)
output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)

output_errors = _average(
xp.abs(y_pred - y_true), weights=sample_weight, axis=0, xp=xp
)
if isinstance(multioutput, str):
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
multioutput = None

return np.average(output_errors, weights=multioutput)
# Average across the outputs (if needed).
mean_absolute_error = _average(output_errors, weights=multioutput)

# Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
# should always return a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value, irrespective of the
# Array API implementation.
assert mean_absolute_error.shape == ()
return float(mean_absolute_error)


@validate_params(
Expand Down
11 changes: 11 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,13 @@ def check_array_api_regression_metric_multioutput(
)


def check_array_api_multioutput_regression_metric(
metric, array_namespace, device, dtype_name
):
metric = partial(metric, multioutput="raw_values")
check_array_api_regression_metric(metric, array_namespace, device, dtype_name)


array_api_metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
Expand All @@ -1893,6 +1900,10 @@ def check_array_api_regression_metric_multioutput(
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_absolute_error: [
check_array_api_regression_metric,
check_array_api_multioutput_regression_metric,
],
}


Expand Down

0 comments on commit 9f44f1f

Please sign in to comment.