Skip to content

Commit

Permalink
Merge pull request #2339 from xiexinch/resize-shortest-edge
Browse files Browse the repository at this point in the history
[Feature] Add ResizeShortestEdge transform
  • Loading branch information
MeowZheng committed Dec 1, 2022
2 parents 383826f + 3b731ed commit 0cdab72
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mmseg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .version import __version__, version_info

MMCV_MIN = '2.0.0rc1'
MMCV_MIN = '2.0.0rc3'
MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.1.0'
MMENGINE_MAX = '1.0.0'
Expand Down
5 changes: 3 additions & 2 deletions mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeToMultiple, RGB2Gray, SegRescale)
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
from .voc import PascalVOCDataset

__all__ = [
Expand All @@ -36,5 +37,5 @@
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset'
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
]
8 changes: 5 additions & 3 deletions mmseg/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
LoadImageFromNDArray)
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
RGB2Gray, SegRescale)
RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)

__all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
]
84 changes: 84 additions & 0 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,87 @@ def __repr__(self):
repr_str += f'edge_width={self.edge_width}, '
repr_str += f'ignore_index={self.ignore_index})'
return repr_str


@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
    """Resize the image and mask while keeping the aspect ratio unchanged.

    Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
    Copyright (c) Facebook, Inc. and its affiliates.
    Licensed under the Apache-2.0 License

    This transform attempts to scale the shorter edge to the given
    `scale`, as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer
    edge does not exceed `max_size`.

    Required Keys:

    - img
    - gt_seg_map (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        scale (Union[int, Tuple[int, int]]): The target short edge length.
            If it's a tuple (or list), the min value is used as the short
            edge length.
        max_size (int): The maximum allowed longest edge length.
    """

    def __init__(self, scale: Union[int, Tuple[int, int]],
                 max_size: int) -> None:
        super().__init__()
        # NOTE: `scale` is intentionally not validated here. Wrapper
        # transforms may construct this object with a placeholder value and
        # assign the real `scale` attribute before calling `transform`.
        self.scale = scale
        self.max_size = max_size

        # Create an empty Resize object. Its `scale` is only a placeholder;
        # the real target shape is computed per sample in `transform`.
        self.resize = TRANSFORMS.build({
            'type': 'Resize',
            'scale': 0,
            'keep_ratio': True
        })

    def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
        """Compute the target image shape with the given `short_edge_length`.

        Args:
            img (np.ndarray): The input image.
            short_edge_length (Union[int, Tuple[int, int]]): The target short
                edge length. If it's a tuple (or list), the min value is
                used as the short edge length.

        Returns:
            Tuple[int, int]: The target shape as ``(new_w, new_h)``.

        Raises:
            TypeError: If `short_edge_length` is not an int, tuple or list.
        """
        h, w = img.shape[:2]
        if isinstance(short_edge_length, int):
            size = short_edge_length * 1.0
        elif isinstance(short_edge_length, (tuple, list)):
            # Lists are accepted as well as tuples: configs that round-trip
            # through JSON/YAML turn tuples into lists.
            size = min(short_edge_length) * 1.0
        else:
            # Previously this case fell through and crashed later with an
            # opaque UnboundLocalError on `size`.
            raise TypeError(
                'short_edge_length must be an int, a tuple or a list, '
                f'but got {type(short_edge_length).__name__}')

        # Scale so that the shorter image edge becomes exactly `size`.
        scale = size / min(h, w)
        if h < w:
            new_h, new_w = size, scale * w
        else:
            new_h, new_w = scale * h, size

        # If the longer edge overshoots `max_size`, shrink both edges by the
        # same factor so the aspect ratio is preserved.
        if max(new_h, new_w) > self.max_size:
            scale = self.max_size * 1.0 / max(new_h, new_w)
            new_h *= scale
            new_w *= scale

        # Round half up to the nearest integer pixel size.
        new_h = int(new_h + 0.5)
        new_w = int(new_w + 0.5)
        return (new_w, new_h)

    def transform(self, results: Dict) -> Dict:
        """Resize ``results`` so the shorter edge of ``img`` matches `scale`.

        Args:
            results (Dict): Result dict; must contain the key ``img``.

        Returns:
            Dict: The value returned by the wrapped ``Resize`` transform
            when called on ``results``.
        """
        # The wrapped Resize object is reused across calls; only its target
        # scale is recomputed from the current image.
        self.resize.scale = self._get_output_shape(results['img'], self.scale)
        return self.resize(results)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(scale={self.scale}, '
        repr_str += f'max_size={self.max_size})'
        return repr_str
31 changes: 31 additions & 0 deletions tests/test_datasets/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules

register_all_modules()


def test_resize():
Expand Down Expand Up @@ -71,6 +74,34 @@ def test_resize():
resized_results = resize_module(results.copy())
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1

# test RandomChoiceResize whose `resize_type` is `ResizeShortestEdge`
transform = dict(
type='RandomChoiceResize',
scales=[128, 256, 512],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]

transform = dict(
type='RandomChoiceResize',
scales=[512],
resize_type='ResizeShortestEdge',
max_size=512)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][1] == 512

transform = dict(
type='RandomChoiceResize',
scales=[(128, 256), (256, 512), (512, 1024)],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]

# test scale=None and scale_factor is tuple.
# img shape: (288, 512, 3)
transform = dict(
Expand Down

0 comments on commit 0cdab72

Please sign in to comment.