diff --git a/pyhealth/metrics/binary.py b/pyhealth/metrics/binary.py
index 33029fc5..16075808 100644
--- a/pyhealth/metrics/binary.py
+++ b/pyhealth/metrics/binary.py
@@ -1,4 +1,5 @@
 from typing import List, Optional, Dict
+
 import numpy as np
 import sklearn.metrics as sklearn_metrics
 
@@ -9,6 +10,44 @@ def binary_metrics_fn(
     metrics: Optional[List[str]] = None,
     threshold: float = 0.5,
 ) -> Dict[str, float]:
+    """Computes metrics for binary classification.
+
+    Users can specify which metrics to compute by passing a list of metric names.
+    The accepted metric names are:
+        - pr_auc: area under the precision-recall curve
+        - roc_auc: area under the receiver operating characteristic curve
+        - accuracy: accuracy score
+        - balanced_accuracy: balanced accuracy score (usually used for imbalanced
+          datasets)
+        - f1: f1 score
+        - precision: precision score
+        - recall: recall score
+        - cohen_kappa: Cohen's kappa score
+        - jaccard: Jaccard similarity coefficient score
+    If no metrics are specified, pr_auc, roc_auc, and f1 are computed by default.
+
+    This function calls sklearn.metrics functions to compute the metrics. For
+    more information on the metrics, please refer to the documentation of the
+    corresponding sklearn.metrics functions.
+
+    Args:
+        y_true: True target values of shape (n_samples,).
+        y_prob: Predicted probabilities of shape (n_samples,).
+        metrics: List of metrics to compute. Default is ["pr_auc", "roc_auc", "f1"].
+        threshold: Threshold for binary classification. Default is 0.5.
+
+    Returns:
+        Dictionary of metrics whose keys are the metric names and values are
+        the metric values.
+
+    Examples:
+        >>> import numpy as np
+        >>> from pyhealth.metrics import binary_metrics_fn
+        >>> y_true = np.array([0, 0, 1, 1])
+        >>> y_prob = np.array([0.1, 0.4, 0.35, 0.8])
+        >>> binary_metrics_fn(y_true, y_prob, metrics=["accuracy"])
+        {'accuracy': 0.75}
+    """
     if metrics is None:
         metrics = ["pr_auc", "roc_auc", "f1"]
 
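To complement the doctest, here is a slightly larger smoke test of the threshold behavior. This is a hypothetical snippet for review purposes, not part of the diff; it assumes pyhealth and its scikit-learn dependency are installed.

    import numpy as np

    from pyhealth.metrics import binary_metrics_fn

    y_true = np.array([0, 0, 1, 1])
    y_prob = np.array([0.1, 0.4, 0.35, 0.8])

    # With the default threshold of 0.5, only the 0.8 prediction binarizes to
    # class 1, so one true positive is missed: precision 1.0, recall 0.5.
    print(binary_metrics_fn(y_true, y_prob, metrics=["precision", "recall"]))

    # Lowering the threshold to 0.3 also sends 0.35 and 0.4 to class 1,
    # trading precision (2/3) for recall (1.0).
    print(binary_metrics_fn(y_true, y_prob,
                            metrics=["precision", "recall"], threshold=0.3))
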
diff --git a/pyhealth/metrics/multiclass.py b/pyhealth/metrics/multiclass.py
index 9b7b4be8..f609e2c6 100644
--- a/pyhealth/metrics/multiclass.py
+++ b/pyhealth/metrics/multiclass.py
@@ -9,6 +9,56 @@ def multiclass_metrics_fn(
     y_prob: np.ndarray,
     metrics: Optional[List[str]] = None,
 ) -> Dict[str, float]:
+    """Computes metrics for multiclass classification.
+
+    Users can specify which metrics to compute by passing a list of metric names.
+    The accepted metric names are:
+        - roc_auc_macro_ovo: area under the receiver operating characteristic curve,
+          macro averaged over one-vs-one multiclass classification
+        - roc_auc_macro_ovr: area under the receiver operating characteristic curve,
+          macro averaged over one-vs-rest multiclass classification
+        - roc_auc_weighted_ovo: area under the receiver operating characteristic curve,
+          weighted averaged over one-vs-one multiclass classification
+        - roc_auc_weighted_ovr: area under the receiver operating characteristic curve,
+          weighted averaged over one-vs-rest multiclass classification
+        - accuracy: accuracy score
+        - balanced_accuracy: balanced accuracy score (usually used for imbalanced
+          datasets)
+        - f1_micro: f1 score, micro averaged
+        - f1_macro: f1 score, macro averaged
+        - f1_weighted: f1 score, weighted averaged
+        - jaccard_micro: Jaccard similarity coefficient score, micro averaged
+        - jaccard_macro: Jaccard similarity coefficient score, macro averaged
+        - jaccard_weighted: Jaccard similarity coefficient score, weighted averaged
+        - cohen_kappa: Cohen's kappa score
+    If no metrics are specified, accuracy, f1_macro, and f1_micro are computed
+    by default.
+
+    This function calls sklearn.metrics functions to compute the metrics. For
+    more information on the metrics, please refer to the documentation of the
+    corresponding sklearn.metrics functions.
+
+    Args:
+        y_true: True target values of shape (n_samples,).
+        y_prob: Predicted probabilities of shape (n_samples, n_classes).
+        metrics: List of metrics to compute. Default is ["accuracy", "f1_macro",
+            "f1_micro"].
+
+    Returns:
+        Dictionary of metrics whose keys are the metric names and values are
+        the metric values.
+
+    Examples:
+        >>> import numpy as np
+        >>> from pyhealth.metrics import multiclass_metrics_fn
+        >>> y_true = np.array([0, 1, 2, 2])
+        >>> y_prob = np.array([[0.9,  0.05, 0.05],
+        ...                    [0.05, 0.9,  0.05],
+        ...                    [0.05, 0.05, 0.9],
+        ...                    [0.6,  0.2,  0.2]])
+        >>> multiclass_metrics_fn(y_true, y_prob, metrics=["accuracy"])
+        {'accuracy': 0.75}
+    """
     if metrics is None:
         metrics = ["accuracy", "f1_macro", "f1_micro"]
 
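Likewise for the multiclass docstring, a hypothetical sketch (not part of the diff) of why the default metric list includes both f1_micro and f1_macro: for single-label multiclass data, micro-averaged f1 equals plain accuracy, while macro-averaged f1 weights every class equally and so exposes a missed minority class.

    import numpy as np

    from pyhealth.metrics import multiclass_metrics_fn

    y_true = np.array([0, 0, 0, 1, 2, 2])
    y_prob = np.array([
        [0.8, 0.1, 0.1],
        [0.7, 0.2, 0.1],
        [0.6, 0.3, 0.1],
        [0.5, 0.3, 0.2],  # the only class-1 sample, predicted as class 0
        [0.1, 0.1, 0.8],
        [0.2, 0.1, 0.7],
    ])

    # accuracy and f1_micro are both 5/6 (about 0.83), but f1_macro falls to
    # roughly 0.62 because class 1 contributes an f1 of 0.
    print(multiclass_metrics_fn(y_true, y_prob,
                                metrics=["accuracy", "f1_micro", "f1_macro"]))
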
diff --git a/pyhealth/metrics/multilabel.py b/pyhealth/metrics/multilabel.py
index 1befe34d..02707441 100644
--- a/pyhealth/metrics/multilabel.py
+++ b/pyhealth/metrics/multilabel.py
@@ -10,6 +10,64 @@ def multilabel_metrics_fn(
     metrics: Optional[List[str]] = None,
     threshold: float = 0.5,
 ) -> Dict[str, float]:
+    """Computes metrics for multilabel classification.
+
+    Users can specify which metrics to compute by passing a list of metric names.
+    The accepted metric names are:
+        - roc_auc_micro: area under the receiver operating characteristic curve,
+          micro averaged
+        - roc_auc_macro: area under the receiver operating characteristic curve,
+          macro averaged
+        - roc_auc_weighted: area under the receiver operating characteristic curve,
+          weighted averaged
+        - roc_auc_samples: area under the receiver operating characteristic curve,
+          samples averaged
+        - pr_auc_micro: area under the precision-recall curve, micro averaged
+        - pr_auc_macro: area under the precision-recall curve, macro averaged
+        - pr_auc_weighted: area under the precision-recall curve, weighted averaged
+        - pr_auc_samples: area under the precision-recall curve, samples averaged
+        - accuracy: accuracy score
+        - f1_micro: f1 score, micro averaged
+        - f1_macro: f1 score, macro averaged
+        - f1_weighted: f1 score, weighted averaged
+        - f1_samples: f1 score, samples averaged
+        - precision_micro: precision score, micro averaged
+        - precision_macro: precision score, macro averaged
+        - precision_weighted: precision score, weighted averaged
+        - precision_samples: precision score, samples averaged
+        - recall_micro: recall score, micro averaged
+        - recall_macro: recall score, macro averaged
+        - recall_weighted: recall score, weighted averaged
+        - recall_samples: recall score, samples averaged
+        - jaccard_micro: Jaccard similarity coefficient score, micro averaged
+        - jaccard_macro: Jaccard similarity coefficient score, macro averaged
+        - jaccard_weighted: Jaccard similarity coefficient score, weighted averaged
+        - jaccard_samples: Jaccard similarity coefficient score, samples averaged
+        - hamming_loss: Hamming loss
+    If no metrics are specified, pr_auc_samples is computed by default.
+
+    This function calls sklearn.metrics functions to compute the metrics. For
+    more information on the metrics, please refer to the documentation of the
+    corresponding sklearn.metrics functions.
+
+    Args:
+        y_true: True target values of shape (n_samples, n_labels).
+        y_prob: Predicted probabilities of shape (n_samples, n_labels).
+        metrics: List of metrics to compute. Default is ["pr_auc_samples"].
+        threshold: Threshold to binarize the predicted probabilities. Default is 0.5.
+
+    Returns:
+        Dictionary of metrics whose keys are the metric names and values are
+        the metric values.
+
+    Examples:
+        >>> import numpy as np
+        >>> from pyhealth.metrics import multilabel_metrics_fn
+        >>> y_true = np.array([[0, 1, 1], [1, 0, 1]])
+        >>> y_prob = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0.6]])
+        >>> multilabel_metrics_fn(y_true, y_prob, metrics=["accuracy"])
+        {'accuracy': 0.5}
+    """
     if metrics is None:
         metrics = ["pr_auc_samples"]
 
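Finally, a hypothetical usage sketch (not part of the diff) for the multilabel docstring. Per the Args section, threshold binarizes y_prob, so it should only matter for the thresholded metrics (accuracy, f1_*, precision_*, recall_*, jaccard_*, hamming_loss); the roc_auc_* and pr_auc_* families are ranking metrics computed from y_prob directly.

    import numpy as np

    from pyhealth.metrics import multilabel_metrics_fn

    y_true = np.array([[0, 1, 1], [1, 0, 1]])
    y_prob = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0.6]])

    # Default metric list is ["pr_auc_samples"]; the threshold is irrelevant.
    print(multilabel_metrics_fn(y_true, y_prob))

    # Raising the threshold to 0.7 drops the 0.6 prediction from the positive
    # set, which lowers recall-driven scores such as f1_micro.
    print(multilabel_metrics_fn(y_true, y_prob,
                                metrics=["f1_micro"], threshold=0.7))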