Skip to content

Commit

Permalink
perf: improve serialization performance (#243)
Browse files Browse the repository at this point in the history
- close the buffer
- use highest pickle protocol by default
- benchmarking in unit tests


| Change | Improvement |
|---|---|
| Closing buffer | 3.2% |
| Protocol update | 15.5% |

---------

Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Aug 16, 2023
1 parent 7ef40a0 commit 8482c01
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
25 changes: 18 additions & 7 deletions numalogic/registry/_serialize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import io
import pickle
from typing import Union

import torch

from numalogic.tools.types import artifact_t, state_dict_t


# TODO: add support for other serialization techniques
def dumps(
    deserialized_object: Union[artifact_t, state_dict_t],
    pickle_protocol: int = pickle.HIGHEST_PROTOCOL,
) -> bytes:
    """Serialize an artifact or state dict into bytes using ``torch.save``.

    Args:
        deserialized_object: The object to serialize.
        pickle_protocol: Pickle protocol forwarded to ``torch.save``;
            defaults to ``pickle.HIGHEST_PROTOCOL``.

    Returns:
        The serialized payload as ``bytes``.
    """
    # The context manager closes the in-memory buffer even if torch.save raises.
    with io.BytesIO() as buffer:
        torch.save(deserialized_object, buffer, pickle_protocol=pickle_protocol)
        # getvalue() hands back a bytes object that stays valid after the buffer closes.
        return buffer.getvalue()


def loads(serialized_object: bytes) -> Union[artifact_t, state_dict_t]:
    """Deserialize a bytes payload back into an object using ``torch.load``.

    Args:
        serialized_object: Bytes previously produced by ``torch.save``
            (for example through ``dumps``).

    Returns:
        The object reconstructed by ``torch.load``.
    """
    # SECURITY: torch.load unpickles the payload, so only pass bytes that come
    # from a trusted source.
    # The context manager closes the in-memory buffer even if torch.load raises.
    with io.BytesIO(serialized_object) as buffer:
        return torch.load(buffer)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.dev0"
version = "0.6.dev1"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
38 changes: 37 additions & 1 deletion tests/registry/test_serialize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
import pickle
import timeit
import unittest

from sklearn.preprocessing import StandardScaler
from torchinfo import summary

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry._serialize import loads, dumps

from numalogic.models.autoencoder.variants import VanillaAE

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TestSerialize(unittest.TestCase):
Expand All @@ -21,3 +27,33 @@ def test_dumps_loads2(self):
serialized_obj = dumps(model)
deserialized_obj = loads(serialized_obj)
self.assertEqual(model.mean_, deserialized_obj.mean_)

def test_benchmark_state_dict_vs_model(self):
    """Compare load time of a serialized state dict against a serialized model.

    The comparison is informational only: a slower state dict load is
    logged as a warning and never fails the test.
    """
    net = VanillaAE(10, 2)
    state_dict_bytes = dumps(net.state_dict())
    model_bytes = dumps(net)

    time_model = timeit.timeit(lambda: loads(model_bytes), number=100)
    time_state_dict = timeit.timeit(lambda: loads(state_dict_bytes), number=100)

    # Timing is noisy, so report instead of failing when the state dict is not faster.
    if not time_state_dict < time_model:
        LOGGER.warning(
            "The state_dict time %.3f is more than the model time %.3f",
            time_state_dict,
            time_model,
        )

def test_benchmark_protocol(self):
    """Compare load time of a protocol-1 payload against a highest-protocol payload.

    The comparison is informational only: a slower highest-protocol load is
    logged as a warning and never fails the test.
    """
    net = VanillaAE(10, 2)
    # NOTE(review): protocol 1 is labelled "default" here, while torch.save's own
    # default is documented as 2 -- confirm which baseline is intended.
    low_protocol_bytes = dumps(net, pickle_protocol=1)
    high_protocol_bytes = dumps(net, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    time_low = timeit.timeit(lambda: loads(low_protocol_bytes), number=1000)
    time_high = timeit.timeit(lambda: loads(high_protocol_bytes), number=1000)

    # Timing is noisy, so report instead of failing when the highest protocol is not faster.
    if not time_high < time_low:
        LOGGER.warning(
            "The default protocol time %.3f is less than the highest protocol time %.3f",
            time_low,
            time_high,
        )

0 comments on commit 8482c01

Please sign in to comment.