Source code for homura.metrics.segmentation

# Metrics for semantic and instance segmentation.

import torch
from torch.nn import functional as F

from homura.metrics.commons import confusion_matrix

__all__ = ["binary_as_multiclass", "pixel_accuracy", "mean_iou", "classwise_iou"]


def binary_as_multiclass(input: torch.Tensor, threshold: float) -> torch.Tensor:
    """ Expand single-channel `Bx1xHxW` outputs to two-channel `Bx2xHxW`.

    A constant plane filled with `threshold` is placed in channel 0 and the
    original values are kept in channel 1, so an `argmax` over the channel
    dimension picks channel 1 wherever the input exceeds `threshold`.

    :param input: tensor whose dimension 1 (channels) has size 1
    :param threshold: value used to fill the prepended channel
    :return: concatenation of the constant plane and `input` along dimension 1
    :raises RuntimeError: if the channel dimension of `input` is not 1
    """
    num_channels = input.size(1)
    if num_channels != 1:
        raise RuntimeError(f"Channel dimension is expected to be 1, but got {num_channels}")

    # Work on a copy so the caller's tensor is left untouched by the in-place fill.
    constant_plane = input.clone().fill_(threshold)
    return torch.cat((constant_plane, input), dim=1)
def pixel_accuracy(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """ Pixel accuracy

    For every sample, the fraction of pixels whose arg-max class equals the
    target label is computed; the per-sample fractions are then averaged over
    the batch.

    Labels outside `[0, C)` can never equal an arg-max index, so such pixels
    are counted as incorrect (they still contribute to the `H*W` denominator).

    :param input: logits (`BxCxHxW`)
    :param target: target in LongTensor (`BxHxW`)
    :return: scalar float tensor holding the batch-averaged pixel accuracy
    :raises RuntimeError: if `input` is not 4-dimensional or `target` is not
        3-dimensional
    """
    if input.dim() != 4:
        raise RuntimeError(f"Dimension of input is expected to be 4, but got {input.dim()}")
    if target.dim() != 3:
        raise RuntimeError(f"Dimension of target is expected to be 3, but got {target.dim()}")

    height, width = input.shape[-2:]
    # Compare class indices directly instead of multiplying two one-hot
    # encodings: the count of matches is the same, but no `BxHxWxC` integer
    # tensors have to be materialised.
    num_correct = (input.argmax(dim=1) == target).sum(dim=(1, 2))
    per_sample_acc = num_correct.float() / (height * width)
    return per_sample_acc.mean()
def classwise_iou(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """ Class-wise IoU

    Builds the confusion matrix of `input` against `target` and, for each
    class, divides its diagonal entry by (column sum + row sum - diagonal
    entry).

    NOTE: a class whose row and column in the confusion matrix are both all
    zero gives `0 / 0`, i.e. NaN, in the returned tensor.

    :param input: logits (`BxCxHxW`)
    :param target: target in LongTensor (`BxHxW`)
    :return: float tensor with one IoU value per diagonal entry of the
        confusion matrix
    :raises RuntimeError: if `input` is not 4-dimensional or `target` is not
        3-dimensional
    """
    input_ndim = input.dim()
    if input_ndim != 4:
        raise RuntimeError(f"Dimension of input is expected to be 4, but got {input_ndim}")
    target_ndim = target.dim()
    if target_ndim != 3:
        raise RuntimeError(f"Dimension of target is expected to be 3, but got {target_ndim}")

    matrix = confusion_matrix(input, target).float()
    intersection = matrix.diag()
    # Row and column sums each include the diagonal, so subtract it once.
    union = matrix.sum(0) + matrix.sum(1) - intersection
    return intersection / union
def mean_iou(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """ Mean IoU

    Averages the per-class values returned by `classwise_iou` into a single
    scalar.

    :param input: logits (`BxCxHxW`)
    :param target: target in LongTensor (`BxHxW`)
    :return: scalar tensor, the mean of the class-wise IoU values
    """
    per_class_iou = classwise_iou(input, target)
    return per_class_iou.mean()