# Source code for homura.metrics.commons

from typing import Tuple

import torch
from torch import Tensor

from homura.liblog import get_logger

logger = get_logger(__name__)

__all__ = ["true_positive", "true_negative", "false_positive", "false_negative",
           "classwise_accuracy", "precision", "recall", "specificity", "f1_score",
           "confusion_matrix", "accuracy"]


def _base(input: Tensor,
          target: Tensor
          ) -> Tuple[Tensor, Tensor, Tensor]:
    """Flatten predictions and targets to column vectors so that they
    broadcast against a row vector of class indices.

    :param input: output of network, `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, `Bx(OPTIONAL DIMENSIONS)`
    :return: (predictions as `Nx1`, targets as `Nx1`, class indices as `C`)
    """

    num_classes = input.size(1)
    class_ids = torch.arange(num_classes, device=input.device)
    predictions = input.argmax(dim=1).reshape(-1, 1)
    flat_target = target.reshape(-1, 1)
    return predictions, flat_target, class_ids


def true_positive(input: Tensor,
                  target: Tensor
                  ) -> Tensor:
    """Calculate true positive

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: true positive in float tensor of `C`
    """

    pred, target, classes = _base(input, target)
    # predicted class c AND actual class c -> a true positive for class c
    hits = (pred == classes) & (target == classes)
    return hits.sum(dim=0).float()
def true_negative(input: Tensor,
                  target: Tensor
                  ) -> Tensor:
    """Calculate true negative

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: true negative in float tensor of `C`
    """

    pred, target, classes = _base(input, target)
    # neither predicted nor actual class c -> a true negative for class c
    misses_both = (pred != classes) & (target != classes)
    return misses_both.sum(dim=0).float()
def false_positive(input: Tensor,
                   target: Tensor
                   ) -> Tensor:
    """Calculate false positive

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: false positive in float tensor of `C`
    """

    pred, target, classes = _base(input, target)
    # predicted class c while actual class differs -> a false positive for c
    wrong_hits = (pred == classes) & (target != classes)
    return wrong_hits.sum(dim=0).float()
def false_negative(input: Tensor,
                   target: Tensor
                   ) -> Tensor:
    """Calculate false negative

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: false negative in float tensor of `C`
    """

    pred, target, classes = _base(input, target)
    # actual class c while prediction differs -> a false negative for c
    missed = (pred != classes) & (target == classes)
    return missed.sum(dim=0).float()
def classwise_accuracy(input: Tensor,
                       target: Tensor
                       ) -> Tensor:
    """Calculate class wise accuracy

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: class wise accuracy in float tensor of `C`
    """

    tp = true_positive(input, target)
    tn = true_negative(input, target)
    fp = false_positive(input, target)
    fn = false_negative(input, target)
    denom = tp + tn + fp + fn
    # Tensor.any() runs the reduction in C instead of the builtin any(),
    # which iterates the tensor element-by-element in Python.
    if (denom == 0).any():
        logger.warning("Zero division in accuracy")
    # NOTE: division still proceeds, producing nan/inf for zero denominators
    # (unchanged behavior; the warning above is the only guard).
    return (tp + tn) / denom
def precision(input: Tensor,
              target: Tensor
              ) -> Tensor:
    """Calculate precision

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: precision in float tensor of `C`
    """

    tp = true_positive(input, target)
    fp = false_positive(input, target)
    denom = tp + fp
    # Tensor.any() runs the reduction in C instead of the builtin any(),
    # which iterates the tensor element-by-element in Python.
    if (denom == 0).any():
        logger.warning("Zero division in precision")
    # NOTE: division still proceeds, producing nan for zero denominators
    # (unchanged behavior; the warning above is the only guard).
    return tp / denom
def recall(input: Tensor,
           target: Tensor
           ) -> Tensor:
    """Calculate recall

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: recall in float tensor of `C`
    """

    tp = true_positive(input, target)
    fn = false_negative(input, target)
    denom = tp + fn
    # Tensor.any() runs the reduction in C instead of the builtin any(),
    # which iterates the tensor element-by-element in Python.
    if (denom == 0).any():
        logger.warning("Zero division in recall")
    # NOTE: division still proceeds, producing nan for zero denominators
    # (unchanged behavior; the warning above is the only guard).
    return tp / denom
def specificity(input: Tensor,
                target: Tensor
                ) -> Tensor:
    """Calculate specificity

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: specificity in float tensor of `C`
    """

    tn = true_negative(input, target)
    fp = false_positive(input, target)
    denom = tn + fp
    # Tensor.any() runs the reduction in C instead of the builtin any(),
    # which iterates the tensor element-by-element in Python.
    if (denom == 0).any():
        logger.warning("Zero division in specificity")
    # NOTE: division still proceeds, producing nan for zero denominators
    # (unchanged behavior; the warning above is the only guard).
    return tn / denom
def f1_score(input: Tensor,
             target: Tensor
             ) -> Tensor:
    """Calculate f1 score

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: f1 score in float tensor of `C`
    """

    # harmonic mean of precision and recall, per class
    p = precision(input, target)
    r = recall(input, target)
    return (2 * p * r) / (p + r)
def confusion_matrix(input: Tensor,
                     target: Tensor
                     ) -> Tensor:
    """Calculate confusion matrix

    :param input: output of network, expected to be `BxCx(OPTIONAL DIMENSIONS)`
    :param target: target, expected to be `Bx(OPTIONAL DIMENSIONS)`
    :return: confusion matrix in long tensor of `CxC`
    """

    c = input.size(1)
    # keep only samples whose target lies in the valid label range [0, c)
    valid = (target >= 0) & (target < c)
    predictions = input.argmax(dim=1)
    # encode each (prediction, target) pair as a single flat index,
    # then histogram: rows index predictions, columns index targets
    flat = predictions[valid] * c + target[valid]
    return flat.bincount(minlength=c * c).view(c, c)
@torch.no_grad()
def accuracy(input: Tensor,
             target: Tensor,
             top_k: int = 1
             ) -> Tensor:
    """Calculate top-k accuracy over the last dimension of ``input``.

    :param input: output of network, class scores in the last dimension
    :param target: target class indices
    :param top_k: a prediction counts as correct if the target is among
        the ``top_k`` highest-scoring classes
    :return: accuracy as a scalar float tensor
    """

    if top_k == 1:
        pred = input.argmax(dim=-1, keepdim=True)
    else:
        pred = input.topk(k=top_k, dim=-1).indices
    # broadcast the target against every one of the top-k predictions
    expanded = target.view(-1, 1).expand_as(pred)
    correct = (pred == expanded).float()
    return correct.sum(dim=1).mean()