Skip to content

Commit

Permalink
Fix KeepSizeByResize crashing
Browse files Browse the repository at this point in the history
This patch fixes `KeepSizeByResize` potentially crashing if a
single numpy array was provided as the input for an iterable
of images (as opposed to a list of numpy arrays).
  • Loading branch information
aleju committed Jan 22, 2020
1 parent b282b97 commit cb6aad2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
3 changes: 3 additions & 0 deletions changelogs/master/fixed/20200122_fix_keepsizebyresize.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
* Fixed `KeepSizeByResize` potentially crashing if a single numpy array
was provided as the input for an iterable of images (as opposed to
a list of numpy arrays). #590
4 changes: 3 additions & 1 deletion imgaug/augmenters/size.py
Original file line number Diff line number Diff line change
Expand Up @@ -4479,7 +4479,9 @@ def _keep_size_images(cls, images, shapes_orig, images_were_array,
# note here that NO_RESIZE can have led to different shapes
nb_shapes = len({image.shape for image in result})
if nb_shapes == 1:
result = np.array(result, dtype=images.dtype)
# images.dtype does not necessarily work anymore, children
# might have turned 'images' into list
result = np.array(result, dtype=result[0].dtype)

return result

Expand Down
11 changes: 11 additions & 0 deletions test/augmenters/test_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -6707,6 +6707,17 @@ def test_image_interpolation_is_no_resize(self):
assert observed.dtype.type == np.uint8
assert np.allclose(observed, expected)

def test_images_input_is_single_array(self):
# input is single array, children turn in into list of arrays()
# => must be combined to a single output array
images = np.zeros((10, 100, 100), dtype=np.uint8)
aug = iaa.KeepSizeByResize(iaa.Crop((0, 40), keep_size=False))

images_aug = aug(images=images)

assert images.dtype.name == "uint8"
assert images.shape == (10, 100, 100)

def test_keypoints_interpolation_is_cubic(self):
aug = iaa.KeepSizeByResize(self.children, interpolation="cubic")
kpsoi = self.kpsoi
Expand Down

0 comments on commit cb6aad2

Please sign in to comment.