# the design here is inspired by FAIR's fvcore
from __future__ import annotations
import random
import warnings
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Tuple
import torch
from PIL import Image
from torchvision.transforms import functional as VF, transforms as VT
# public API of this module
__all__ = [
    "TransformBase",
    "ConcatTransform",
    "GeometricTransformBase",
    "NonGeometricTransformBase",
    "RandomResizedCrop",
    "RandomCrop",
    "RandomRotation",
    "RandomHorizontalFlip",
    "CenterCrop",
    "RandomResize",
    "Normalize",
    "ColorJitter",
    "RandomGrayScale",
]

# the kind of target a transform must keep consistent with the image
TargetType = Literal["bbox", "mask"]
class HomuraTransformWarning(UserWarning):
    """Warning category used by this module's transforms."""
# utils
# geometric
class RandomHorizontalFlip(GeometricTransformBase):
    """Horizontally flip the image (and its coordinate target) with probability ``p``."""

    def __init__(self,
                 p: float = 0.5,
                 target_type: Optional[TargetType] = None
                 ):
        super().__init__(target_type)
        self.p = p

    def get_params(self, image) -> Optional:
        # one uniform draw, shared by apply_image and apply_coords so that
        # the image and its target flip together
        return random.random()

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        if params < self.p:
            return VF.hflip(image)
        return image

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh,
                     params
                     ) -> torch.Tensor:
        # mirror the x column about the image width, in place
        # NOTE(review): uses w - x, i.e. continuous/bbox-corner convention;
        # pixel-index coords would need w - 1 - x — confirm intended convention
        if params < self.p:
            coords[:, 0] = original_wh[0] - coords[:, 0]
        return coords

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
def _crop_coods_(coords, top, left, h, w, output_h, output_w):
# crop
coords[:, 0] -= left
coords[:, 1] -= top
coords[:, 0].clamp_(0, w)
coords[:, 1].clamp_(0, h)
# scale
coords[:, 0] *= output_w / w
coords[:, 1] *= output_h / h
return coords.round()
class RandomCrop(GeometricTransformBase):
    """Randomly crop the image (and its target, if any) to ``size``.

    Args:
        size: expected output size (h, w), or a single int for a square crop.
        padding: optional padding applied before cropping. Only allowed when
            ``target_type`` is None, because padded pixels have no
            well-defined mask/bbox counterpart.
        pad_if_needed: pad right/bottom when the image is smaller than ``size``.
        fill: fill value for image padding.
        padding_mode: padding mode forwarded to ``VF.pad``.
        mask_fill: fill value used when padding a mask target
            (255 is the conventional "ignore" index).
        target_type: "bbox" or "mask", or None for classification.
    """

    def __init__(self,
                 size,
                 padding=None,
                 pad_if_needed=False,
                 fill=0,
                 padding_mode="constant",
                 mask_fill=255,
                 target_type: Optional[TargetType] = None):
        super().__init__(target_type)
        self.size = VT._setup_size(size, "Invalid value for size (h, w)")
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.padding_mode = padding_mode
        self.fill = fill
        self.mask_fill = mask_fill
        if self.padding is not None and self.target_type is not None:
            # when reflection padding is applied, what are the expected mask or bbox?
            raise RuntimeError("padding is unexpected for non-classification tasks")
        # FIX: TargetType is Literal["bbox", "mask"]; the original compared
        # against "detection", which could never match, so this warning was dead.
        if self.target_type == "bbox":
            warnings.warn(f"{self.__class__.__name__} expects coordinate origin is at left top. "
                          f"Inconsistency with this may cause unexpected results.",
                          HomuraTransformWarning)

    def get_params(self, image) -> Tuple[int, ...]:
        # (top, left, height, width) of the crop window
        return VT.RandomCrop.get_params(image, self.size)

    def __call__(self,
                 input: torch.Tensor,
                 target: Optional[torch.Tensor] = None
                 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.padding is not None:
            input = VF.pad(input, self.padding, self.fill, self.padding_mode)
        if self.pad_if_needed:
            w, h = VF._get_image_size(input)
            eh, ew = self.size
            pw, ph = max(ew - w, 0), max(eh - h, 0)
            if pw > 0 or ph > 0:
                # pad right/bottom only, so existing coordinates stay valid
                input = VF.pad(input, [0, 0, pw, ph], fill=self.fill)
                # FIX: the original compared against "segmentation", which can
                # never match Literal["bbox", "mask"] — masks were never padded.
                if self.target_type == "mask":
                    target = VF.pad(target, [0, 0, pw, ph], fill=self.mask_fill)
        return super().__call__(input, target)

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        # params come from get_params as (top, left, h, w); the original named
        # them (x, y, h, w), which was misleading but the same order
        top, left, h, w = params
        return VF.crop(image, top, left, h, w)

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh,
                     params
                     ) -> torch.Tensor:
        top, left, h, w = params
        return _crop_coods_(coords, top, left, h, w, self.size[0], self.size[1])

    def __repr__(self):
        return f"{self.__class__.__name__}(size={self.size}, pad={self.pad_if_needed})"
class RandomResize(GeometricTransformBase):
    """Resize using a target size drawn uniformly from [min_size, max_size];
    when ``max_size`` is None the size is always ``min_size``."""

    def __init__(self,
                 min_size: int,
                 max_size: Optional[int] = None,
                 target_type: Optional[TargetType] = None):
        super().__init__(target_type)
        if max_size is not None and min_size > max_size:
            raise ValueError(f"Invalid size: min_size={min_size} > max_size={max_size}")
        self.min_size = min_size
        self.max_size = max_size

    def get_params(self,
                   image: Optional[torch.Tensor]) -> Optional:
        if self.max_size is None:
            # degenerate range: deterministic size
            return self.min_size
        return random.randint(self.min_size, self.max_size)

    def apply_image(self, image: torch.Tensor, params) -> torch.Tensor:
        return VF.resize(image, params)

    def apply_mask(self, mask: torch.Tensor, params) -> torch.Tensor:
        # nearest-neighbor keeps label ids intact
        return VF.resize(mask, params, interpolation=Image.NEAREST)

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh: Tuple[int, int],
                     params
                     ) -> torch.Tensor:
        # coordinate targets are deliberately unsupported here
        raise NotImplementedError()

    def __repr__(self):
        return f"RandomResize(min_size={self.min_size}, max_size={self.max_size})"
class RandomResizedCrop(GeometricTransformBase):
    """Crop a window of random area/aspect ratio and resize it to ``size``."""

    def __init__(self,
                 size,
                 scale=(0.08, 1.0),
                 ratio=(3. / 4., 4. / 3.),
                 target_type=None):
        super().__init__(target_type=target_type)
        self.size = VT._setup_size(size, "Invalid value for size (h, w)")
        self.scale = scale
        self.ratio = ratio

    def get_params(self,
                   image: Optional[torch.Tensor]) -> Optional:
        # (top, left, height, width) sampled by torchvision
        return VT.RandomResizedCrop.get_params(image, self.scale, self.ratio)

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        top, left, height, width = params
        return VF.resized_crop(image, top, left, height, width, self.size)

    def apply_mask(self,
                   mask: torch.Tensor,
                   params
                   ) -> torch.Tensor:
        top, left, height, width = params
        # nearest-neighbor keeps label ids intact
        return VF.resized_crop(mask, top, left, height, width, self.size,
                               interpolation=Image.NEAREST)

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh,
                     params
                     ) -> torch.Tensor:
        top, left, height, width = params
        return _crop_coods_(coords, top, left, height, width,
                            self.size[0], self.size[1])

    def __repr__(self):
        return f"{self.__class__.__name__}(size={self.size}, scale={self.scale}, ratio={self.ratio})"
class RandomRotation(GeometricTransformBase):
    """Rotate the image (and its target) by an angle drawn from ``degrees``.

    Args:
        degrees: rotation range; a scalar ``d`` means (-d, d).
        fill: fill value for the image's out-of-frame pixels.
        mask_fill: fill value for masks (255 is the conventional ignore index).
        target_type: "bbox" or "mask", or None for classification.
    """

    def __init__(self,
                 degrees,
                 fill=None,
                 mask_fill=255,
                 target_type=None):
        super().__init__(target_type=target_type)
        self.degrees = VT._setup_angle(degrees, "degrees", (2,))
        self.fill = fill
        self.mask_fill = mask_fill
        # FIX: TargetType is Literal["bbox", "mask"]; the original compared
        # against "detection", which could never match, so this warning was dead.
        if self.target_type == "bbox":
            warnings.warn("Rotated bbox may exceeds image area. Please check it carefully.", HomuraTransformWarning)

    def get_params(self,
                   image: Optional[torch.Tensor]) -> Optional:
        # one angle, shared by image, mask, and coords
        return VT.RandomRotation.get_params(self.degrees)

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        angle = params
        return VF.rotate(image, angle, fill=self.fill)

    def apply_mask(self,
                   mask: torch.Tensor,
                   params
                   ) -> torch.Tensor:
        angle = params
        # out-of-frame region is marked with the ignore index
        return VF.rotate(mask, angle, fill=self.mask_fill)

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh,
                     params
                     ) -> torch.Tensor:
        # rotate coordinates about the image center (in-place on coords);
        # math intentionally kept exactly as in the original
        original_wh = torch.tensor(original_wh, dtype=torch.float).view(1, 2)
        rad = torch.deg2rad(torch.tensor(params, dtype=torch.float))
        # rotation matrix, applied to row vectors via coords @ rot
        rot = torch.stack([torch.cos(rad), -torch.sin(rad), torch.sin(rad), torch.cos(rad)]).view(2, 2)
        center = original_wh / 2
        coords -= center
        coords @= rot
        coords += center
        return coords.round()

    def __repr__(self):
        return f"{self.__class__.__name__}(degrees={self.degrees})"
class CenterCrop(GeometricTransformBase):
    """Deterministically crop the central (h, w) region."""

    def __init__(self,
                 size,
                 target_type=None):
        super().__init__(target_type)
        self.size = VT._setup_size(size, "Invalid size for (h, w) for size")

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        return VF.center_crop(image, self.size)

    def apply_coords(self,
                     coords: torch.Tensor,
                     original_wh,
                     params
                     ) -> torch.Tensor:
        width, height = original_wh
        out_h, out_w = self.size
        # same rounding of the crop origin as torchvision's center_crop
        top = int((height - out_h + 1) * 0.5)
        left = int((width - out_w + 1) * 0.5)
        return _crop_coods_(coords, top, left, out_h, out_w, out_h, out_w)

    def __repr__(self):
        return f"{self.__class__.__name__}(size={self.size})"
# non geometric
class Normalize(NonGeometricTransformBase):
    """Channel-wise normalization, (image - mean) / std; targets are untouched."""

    def __init__(self,
                 mean: List[float],
                 std: List[float],
                 target_type: Optional[TargetType] = None):
        super().__init__(target_type)
        self.mean = mean
        self.std = std

    def apply_image(self, image: torch.Tensor, params) -> torch.Tensor:
        # params is unused: normalization is deterministic
        return VF.normalize(image, self.mean, self.std)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(mean={self.mean}, std={self.std})"
class RandomGrayScale(NonGeometricTransformBase):
    """Convert the image to grayscale with probability ``p`` by delegating to
    torchvision's ``RandomGrayscale``; targets are untouched.

    Args:
        p: probability of converting the image to grayscale.
        target_type: "bbox" or "mask", or None for classification.
    """

    def __init__(self,
                 p: float = 0.5,
                 target_type: Optional[TargetType] = None):
        super().__init__(target_type)
        self._impl = VT.RandomGrayscale(p)

    def apply_image(self,
                    image: torch.Tensor,
                    params
                    ) -> torch.Tensor:
        return self._impl(image)

    def __repr__(self):
        # FIX: the original f-string was missing the closing parenthesis
        return f"{self.__class__.__name__}(p={self._impl.p})"
class ColorJitter(NonGeometricTransformBase):
    """Randomly jitter brightness/contrast/saturation/hue by delegating to
    torchvision's ``ColorJitter``; targets are untouched."""

    def __init__(self,
                 brightness=0,
                 contrast=0,
                 saturation=0,
                 hue=0,
                 target_type: Optional[TargetType] = None):
        super().__init__(target_type)
        # torchvision validates the ranges and samples the jitter factors
        self._impl = VT.ColorJitter(brightness, contrast, saturation, hue)

    def apply_image(self, image: torch.Tensor, params) -> torch.Tensor:
        return self._impl(image)

    def __repr__(self):
        impl = self._impl
        return f"{self.__class__.__name__}(brightness={impl.brightness}, contrast={impl.contrast}, " \
               f"saturation={impl.saturation}, hue={impl.hue})"