from functools import partial
import torch
from torch import nn
from .functional import cross_entropy_with_smoothing, cross_entropy_with_softlabels
class _LossFunction(nn.Module):
    """Base class for loss modules that wrap a functional loss.

    Subclasses are expected to assign ``self.impl`` in ``__init__``: a
    callable invoked as ``impl(input, target)`` that returns the loss
    tensor. This base class does not define ``impl`` itself, so calling
    ``forward`` on an instance without it raises ``AttributeError``.
    """

    def forward(self,
                input: torch.Tensor,
                target: torch.Tensor
                ) -> torch.Tensor:
        """Return ``self.impl(input, target)``.

        Args:
            input: first positional argument handed to ``impl``
                (presumably the model predictions -- confirm against the
                wrapped functional).
            target: second positional argument handed to ``impl``
                (presumably the labels -- confirm against the wrapped
                functional).

        Returns:
            Whatever ``impl`` returns, unchanged.
        """
        # The parameter is named ``input`` (shadowing the builtin) on
        # purpose: it is part of the public signature and follows the
        # torch.nn convention, so it must not be renamed.
        return self.impl(input, target)
class SoftLabelCrossEntropy(_LossFunction):
    """Module wrapper around ``cross_entropy_with_softlabels``.

    The ``dim`` and ``reduction`` options are bound once at construction
    time; each forward call then only supplies ``input`` and ``target``.

    Args:
        dim: forwarded as the ``dim`` keyword of
            ``cross_entropy_with_softlabels`` (presumably the class
            dimension -- confirm against the functional). Defaults to 1.
        reduction: forwarded as the ``reduction`` keyword of
            ``cross_entropy_with_softlabels``. Defaults to ``"mean"``;
            the set of accepted values is defined by the functional.
    """

    def __init__(self,
                 dim: int = 1,
                 reduction: str = "mean"):
        super().__init__()
        # Pre-bind the configuration so the inherited ``forward`` can call
        # ``self.impl(input, target)`` without knowing about the options.
        self.impl = partial(
            cross_entropy_with_softlabels,
            dim=dim,
            reduction=reduction,
        )
class SmoothedCrossEntropy(_LossFunction):
    """Module wrapper around ``cross_entropy_with_smoothing``.

    The ``smoothing``, ``dim`` and ``reduction`` options are bound once at
    construction time; each forward call then only supplies ``input`` and
    ``target``.

    Args:
        smoothing: forwarded as the ``smoothing`` keyword of
            ``cross_entropy_with_smoothing`` (presumably the label
            smoothing amount -- confirm the valid range against the
            functional). Defaults to 0.1.
        dim: forwarded as the ``dim`` keyword of
            ``cross_entropy_with_smoothing`` (presumably the class
            dimension -- confirm against the functional). Defaults to 1.
        reduction: forwarded as the ``reduction`` keyword of
            ``cross_entropy_with_smoothing``. Defaults to ``"mean"``;
            the set of accepted values is defined by the functional.
    """

    def __init__(self,
                 smoothing: float = 0.1,
                 dim: int = 1,
                 reduction: str = "mean"):
        super().__init__()
        # Pre-bind the configuration so the inherited ``forward`` can call
        # ``self.impl(input, target)`` without knowing about the options.
        self.impl = partial(
            cross_entropy_with_smoothing,
            smoothing=smoothing,
            dim=dim,
            reduction=reduction,
        )