from typing import Optional, Tuple
import torch
[docs]def partial_mixup(input: torch.Tensor,
gamma: float,
indices: torch.Tensor
) -> torch.Tensor:
""" mixup: Beyond Empirical Risk Minimization
:param input:
:param gamma:
:param indices:
:return:
"""
if input.size(0) != indices.size(0):
raise RuntimeError("Size mismatch!")
perm_input = input[indices]
return input.mul(gamma).add(perm_input, alpha=1 - gamma)
[docs]def mixup(input: torch.Tensor,
target: torch.Tensor,
gamma: float,
indices: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
""" mixup: Beyond Empirical Risk Minimization
:param input:
:param target:
:param gamma:
:param indices:
:return:
"""
if indices is None:
indices = torch.randperm(input.size(0), device=input.device, dtype=torch.long)
return partial_mixup(input, gamma, indices), partial_mixup(target, gamma, indices)