# Source code for homura.modules.functional.loss

import inspect
import warnings

import torch


def _reduction(input: torch.Tensor, reduction: str) -> torch.Tensor:
    if reduction == "mean":
        return input.mean()
    elif reduction == "sum":
        return input.sum()
    elif reduction == "none" or reduction is None:
        return input
    else:
        raise NotImplementedError(f"Wrong reduction: {reduction}")


def cross_entropy_with_softlabels(input: torch.Tensor,
                                  target: torch.Tensor,
                                  dim: int = 1,
                                  reduction: str = "mean"
                                  ) -> torch.Tensor:
    """Cross entropy between logits and soft (non one-hot) targets.

    Computes ``-(log_softmax(input, dim) * target).sum(dim)`` and then applies
    ``reduction`` to the result.

    :param input: unnormalized scores; ``log_softmax`` is applied along ``dim``.
    :param target: soft labels of exactly the same size as ``input``. Presumably
        a probability distribution along ``dim`` -- this is not validated here.
    :param dim: dimension along which classes are laid out.
    :param reduction: ``"mean"``, ``"sum"``, or ``"none"``/``None``.
    :return: the reduced loss; with ``"none"``, a tensor shaped like ``input``
        with ``dim`` removed.
    :raises RuntimeError: if ``input`` and ``target`` have different sizes.
    :raises NotImplementedError: if ``reduction`` is not one of the above.
    """
    # ``label_smoothing`` is a constructor argument of CrossEntropyLoss (stored
    # per instance), not a class attribute, so ``hasattr`` on the class is always
    # False. Inspect the constructor signature to detect a PyTorch version whose
    # native cross entropy should be preferred.
    if "label_smoothing" in inspect.signature(torch.nn.CrossEntropyLoss).parameters:
        warnings.warn("Use PyTorch's F.cross_entropy", DeprecationWarning, stacklevel=2)
    if input.size() != target.size():
        raise RuntimeError(f"Input size ({input.size()}) and target size ({target.size()}) should be same!")
    return _reduction(-(input.log_softmax(dim=dim) * target).sum(dim=dim), reduction)
def cross_entropy_with_smoothing(input: torch.Tensor,
                                 target: torch.Tensor,
                                 smoothing: float,
                                 dim: int = 1,
                                 reduction: str = "mean"
                                 ) -> torch.Tensor:
    """Cross entropy with label smoothing for hard (class-index) targets.

    The per-sample loss is
    ``(1 - smoothing) * nll + smoothing * mean(-log_softmax(input, dim), dim)``,
    where ``nll`` is the negative log-probability of the target class.

    :param input: unnormalized scores; ``log_softmax`` is applied along ``dim``.
    :param target: class indices with the shape of ``input`` minus ``dim``. It is
        used as a ``gather`` index, so it must be an int64 (long) tensor.
    :param smoothing: weight given to the uniform (mean over classes) term.
    :param dim: dimension along which classes are laid out.
    :param reduction: ``"mean"``, ``"sum"``, or ``"none"``/``None``.
    :return: the reduced loss; with ``"none"``, a tensor shaped like ``target``.
    :raises NotImplementedError: if ``reduction`` is not one of the above.
    """
    # ``label_smoothing`` is a constructor argument of CrossEntropyLoss (stored
    # per instance), not a class attribute, so ``hasattr`` on the class is always
    # False. Inspect the constructor signature to detect a PyTorch version whose
    # native cross entropy should be preferred.
    if "label_smoothing" in inspect.signature(torch.nn.CrossEntropyLoss).parameters:
        warnings.warn("Use PyTorch's F.cross_entropy", DeprecationWarning, stacklevel=2)
    log_prob = input.log_softmax(dim=dim)
    # Insert a size-1 class axis so ``gather`` picks one log-prob per position,
    # then drop that axis again.
    nll_loss = -log_prob.gather(dim=dim, index=target.unsqueeze(dim=dim))
    nll_loss = nll_loss.squeeze(dim=dim)
    # Uniform-target term: mean negative log-probability over all classes.
    smooth_loss = -log_prob.mean(dim=dim)
    loss = (1 - smoothing) * nll_loss + smoothing * smooth_loss
    return _reduction(loss, reduction)