Skip to content

Commit

Permalink
[feat] Add support for animated image formats (aimhubio#1704)
Browse files Browse the repository at this point in the history
  • Loading branch information
devfox-se committed Apr 28, 2022
1 parent 3631b85 commit 0e33bbb
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 34 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

### Enhancements:

- Add support for animated image formats to Aim Image object (devfox-se)

### Fixes:

## 3.9.x (Unreleased)

### Enhancements:

- Add `Notes Tab` to single run page (arsengit)
- Add the run name to the batch delete and the batch archive modals (VkoHov)
- Increase the scalability of rendering lines in charts (KaroMourad)
Expand Down
83 changes: 53 additions & 30 deletions aim/sdk/objects/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
import os.path

from PIL import Image as PILImage
from PIL import (
Image as PILImage,
ImageSequence as PILImageSequence
)

from io import BytesIO
from itertools import chain, repeat
Expand Down Expand Up @@ -34,19 +37,15 @@ class Image(CustomObject):
quality=85
"""

DEFAULT_IMG_FORMAT = 'png'
FLAG_WARN_RGBA_RGB = False
AIM_NAME = 'aim.image'

def __init__(self, image, caption: str = '', format='png', quality=90, optimize=False):
def __init__(self, image, caption: str = '', format=None, quality=90, optimize=False):
super().__init__()

# normalize jpg
if format.lower() == 'jpg':
# PIL doesn't support 'jpg' key
format = 'jpeg'

params = {
'format': format.lower(),
'format': format,
'quality': quality,
'optimize': optimize
}
Expand Down Expand Up @@ -136,27 +135,51 @@ def _from_pil_image(self, pil_image: PILImage.Image, params):
assert isinstance(pil_image, PILImage.Image)
img_container = BytesIO()

try:
pil_image.save(img_container, **params)
except OSError as exc:
# The best way to approach this problem is to prepare PIL Image object before hitting this method.
# This block only handles case where RGBA/P/LA/PA mode is mandated to save in RGB
# PIL won't do that automatically, so we have to convert image to RGB before saving it.
# In addition - make transparency "white" before conversion otherwise it will be black.
if pil_image.mode not in ('RGBA', 'LA', 'PA', 'P'):
raise
elif not Image.FLAG_WARN_RGBA_RGB:
logger.warning(f'Failed to save the image due to the following error: {exc}')
logger.warning(f'Attempting to convert mode "{pil_image.mode}" to "RGB"')
Image.FLAG_WARN_RGBA_RGB = True

alpha = pil_image.convert('RGBA').split()[-1] # Get only alpha
background = PILImage.new('RGBA', pil_image.size, (255, 255, 255, 255))
background.paste(pil_image, mask=alpha)
pil_image = background.convert('RGB')

# Retry
pil_image.save(img_container, **params)
if not params['format']:
params['format'] = pil_image.format or self.DEFAULT_IMG_FORMAT
else:
# normalize img format
img_format = params['format'].lower()
if img_format == 'jpg':
# PIL doesn't support 'jpg' key
params['format'] = 'jpeg'
params['format'] = img_format

if getattr(pil_image, "n_frames", 1) > 1:
# is animated
frames = PILImageSequence.all_frames(pil_image)
params.update(
dict(
save_all=True,
append_images=frames[1:],
)
)
frames[0].save(
img_container,
**params
)
else:
try:
pil_image.save(img_container, **params)
except OSError as exc:
# The best way to approach this problem is to prepare PIL Image object before hitting this method.
# This block only handles case where RGBA/P/LA/PA mode is mandated to save in RGB
# PIL won't do that automatically, so we have to convert image to RGB before saving it.
# In addition - make transparency "white" before conversion otherwise it will be black.
if pil_image.mode not in ('RGBA', 'LA', 'PA', 'P'):
raise exc
elif not Image.FLAG_WARN_RGBA_RGB:
logger.warning(f'Failed to save the image due to the following error: {exc}')
logger.warning(f'Attempting to convert mode "{pil_image.mode}" to "RGB"')
Image.FLAG_WARN_RGBA_RGB = True

alpha = pil_image.convert('RGBA').split()[-1] # Get only alpha
background = PILImage.new('RGBA', pil_image.size, (255, 255, 255, 255))
background.paste(pil_image, mask=alpha)
pil_image = background.convert('RGB')

# Retry
pil_image.save(img_container, **params)

self.storage['data'] = BLOB(data=img_container.getvalue())
self.storage['source'] = 'PIL.Image'
Expand Down Expand Up @@ -245,7 +268,7 @@ def __eq__(self, other):
if self.storage[p] != other.storage[p]:
return False

return (self.storage['data'].load() == other.storage['data'].load())
return self.storage['data'].load() == other.storage['data'].load()


def convert_to_aim_image_list(images, labels=None) -> List[Image]:
Expand Down
10 changes: 6 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ def decode_encoded_tree_stream(stream: Iterator[bytes], concat_chunks=False) ->


def generate_image_set(img_count, caption_prefix='Image', img_size=(16, 16)):
return [AimImage(
pil_image.fromarray((numpy.random.rand(img_size[0], img_size[1], 3) * 255).astype('uint8')),
f'{caption_prefix} {idx}'
) for idx in range(img_count)]
return [
AimImage(
pil_image.fromarray((numpy.random.rand(img_size[0], img_size[1], 3) * 255).astype('uint8')),
caption=f'{caption_prefix} {idx}'
) for idx in range(img_count)
]


def truncate_structured_db(db):
Expand Down

0 comments on commit 0e33bbb

Please sign in to comment.