Skip to content

Commit

Permalink
fix (RedisRegistry): avoid overwriting cache with the same key during…
Browse files Browse the repository at this point in the history
… load (#223)

- fix cache key overwrite issue which caused a stale model to persist
for a long time
- Save sourcetype in artifact data extras to be used for other logic
- Improve docs

---------

Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Jul 6, 2023
1 parent 0b0daed commit b21e246
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 28 deletions.
4 changes: 4 additions & 0 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class ArtifactManager(Generic[KEYS, A_D]):
uri: server/connection uri
"""

_STORETYPE = "registry"

__slots__ = ("uri",)

def __init__(self, uri: str):
Expand Down Expand Up @@ -137,6 +139,8 @@ class ArtifactCache(Generic[M_K, A_D]):
ttl: time to live for each item in the cache
"""

_STORETYPE = "cache"

__slots__ = ("_cachesize", "_ttl")

def __init__(self, cachesize: int, ttl: int):
Expand Down
11 changes: 11 additions & 0 deletions numalogic/registry/dynamodb_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from typing import Any, Optional
Expand Down
44 changes: 43 additions & 1 deletion numalogic/registry/localcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Optional

from cachetools import TTLCache
Expand All @@ -34,14 +35,55 @@ def __init__(self, cachesize: int = 512, ttl: int = 300):
if not self.__cache:
self.__cache = TTLCache(maxsize=cachesize, ttl=ttl)

def load(self, artifact_key: str) -> ArtifactData:
def __contains__(self, artifact_key: str) -> bool:
    """
    Support the ``in`` operator for cache lookups.

    Args:
    ----
        artifact_key: The key of the artifact to look for.

    Returns
    -------
        True if the key is present in the underlying cache, False otherwise.
    """
    is_present = artifact_key in self.__cache
    return is_present

def load(self, artifact_key: str) -> Optional[ArtifactData]:
    """
    Fetch an artifact from the cache by its key.

    Args:
    ----
        artifact_key: The key of the artifact to load.

    Returns
    -------
        The cached artifact data instance, or None when the key is not in the cache.
    """
    cached = self.__cache.get(artifact_key)
    return cached

def save(self, key: str, artifact: ArtifactData) -> None:
    """
    Store a copy of an artifact in the cache.

    A deep copy is what gets stored, so tagging the cached entry with its
    source does not modify the instance passed in by the caller.

    Args:
    ----
        key: The key of the artifact to save.
        artifact: The artifact data instance to save.
    """
    cached_copy = deepcopy(artifact)
    # Mark where this copy lives so that consumers can tell its origin.
    cached_copy.extras["source"] = self._STORETYPE
    self.__cache[key] = cached_copy

def delete(self, key: str) -> Optional[ArtifactData]:
    """
    Delete an artifact from the cache.

    Args:
    ----
        key: The key of the artifact to delete.

    Returns
    -------
        The deleted artifact data instance if found, None otherwise.
    """
    # The default is passed positionally: every mapping's pop() accepts that form,
    # whereas e.g. the builtin dict.pop() rejects a ``default=`` keyword argument.
    return self.__cache.pop(key, None)

def clear(self) -> None:
    """Remove every entry from the cache."""
    self.__cache.clear()

def keys(self) -> list[str]:
    """
    Return the keys currently held in the cache.

    Returns
    -------
        A new list with the cache keys, in the cache's iteration order.
    """
    # Iterating the cache yields its keys; list() consumes it directly,
    # so no intermediate generator expression is needed.
    return list(self.__cache)
82 changes: 68 additions & 14 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Copyright 2022 The Numaproj Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from datetime import datetime, timedelta
Expand All @@ -18,7 +29,7 @@ class RedisRegistry(ArtifactManager):
Args:
----
client: Take in the reids client already established/created
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (default = 604800)
cache_registry: Cache registry to use (default = None).
Expand Down Expand Up @@ -94,6 +105,7 @@ def _load_from_cache(self, key: str) -> Optional[ArtifactData]:

def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None:
    """
    Save the artifact in the cache registry, when one is set.

    Does nothing if ``self.cache_registry`` is falsy.

    Args:
    ----
        key: The key under which the artifact is cached.
        artifact_data: The artifact data instance to cache.
    """
    if not self.cache_registry:
        return
    _LOGGER.debug("Saving artifact in cache with key: %s", key)
    self.cache_registry.save(key, artifact_data)

def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]:
Expand Down Expand Up @@ -123,20 +135,38 @@ def __get_artifact_data(
extras={
"timestamp": float(artifact_timestamp.decode()),
"version": artifact_version.decode(),
"source": self._STORETYPE,
},
)

def __load_latest_artifact(self, key: str) -> ArtifactData:
def __load_latest_artifact(self, key: str) -> tuple[ArtifactData, bool]:
"""
Load the latest artifact from the registry.
Args:
key: full model key.
Returns
-------
ArtifactData and a boolean flag indicating if the artifact was loaded from cache.
Raises
------
ModelKeyNotFound: If the model key is not found in the registry.
"""
cached_artifact = self._load_from_cache(key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", key)
return cached_artifact
return cached_artifact, True
latest_key = self.__construct_latest_key(key)
if not self.client.exists(latest_key):
raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!")
model_key = self.client.get(latest_key)
_LOGGER.info("latest key, %s, is pointing to the key : %s", latest_key, model_key)
return self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key)
_LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key)
return (
self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key),
False,
)

def __load_version_artifact(self, version: str, key: str) -> ArtifactData:
model_key = self.__construct_version_key(key, version)
Expand All @@ -152,7 +182,7 @@ def __save_artifact(
new_version_key = self.__construct_version_key(key, version)
latest_key = self.__construct_latest_key(key)
pipe.set(name=latest_key, value=new_version_key)
_LOGGER.info("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key)
_LOGGER.debug("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key)
serialized_metadata = ""
if metadata:
serialized_metadata = dumps(deserialized_object=metadata)
Expand All @@ -178,6 +208,8 @@ def load(
"""Loads the artifact from redis registry. Either latest or version (one of the arguments)
is needed to load the respective artifact.
If cache registry is provided, it will first check the cache registry for the artifact.
Args:
----
skeys: static key fields as list/tuple of strings
Expand All @@ -188,19 +220,26 @@ def load(
Returns
-------
ArtifactData instance
Raises
------
ValueError: If both latest and version are provided or none of them are provided.
RedisRegistryError: If any redis error occurs.
"""
if (latest and version) or (not latest and not version):
raise ValueError("Either One of 'latest' or 'version' needed in load method call")
key = self.construct_key(skeys, dkeys)
is_cached = False
try:
if latest:
artifact_data = self.__load_latest_artifact(key)
self._save_in_cache(key, artifact_data)
artifact_data, is_cached = self.__load_latest_artifact(key)
else:
artifact_data = self.__load_version_artifact(version, key)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
if (not is_cached) and latest:
self._save_in_cache(key, artifact_data)
return artifact_data

def save(
Expand All @@ -221,24 +260,28 @@ def save(
Returns
-------
model version
Model version (str)
Raises
------
RedisRegistryError: If there is any RedisError while saving the artifact.
"""
key = self.construct_key(skeys, dkeys)
latest_key = self.__construct_latest_key(key)
version = 0
try:
if self.client.exists(latest_key):
_LOGGER.debug("latest key exists for the model")
_LOGGER.debug("Latest key: %s exists for the model", latest_key)
version_key = self.client.get(name=latest_key)
version = int(self.get_version(version_key.decode())) + 1
with self.client.pipeline() as pipe:
new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version))
pipe.expire(name=new_version_key, time=self.ttl)
_LOGGER.info("Model with the key = %s, loaded successfully.", new_version_key)
pipe.execute()
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
_LOGGER.info("Model with the key = %s, saved successfully.", new_version_key)
return str(version)

def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
Expand All @@ -249,21 +292,25 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
skeys: static key fields as list/tuple of strings
dkeys: dynamic key fields as list/tuple of strings
version: model version to delete.
Raises
------
ModelKeyNotFound: If the model version is not found in registry.
RedisRegistryError: If there is any RedisError while deleting the artifact.
"""
key = self.construct_key(skeys, dkeys)
del_key = self.__construct_version_key(key, version)
try:
if self.client.exists(del_key):
self.client.delete(del_key)
_LOGGER.info("Model with the key = %s, deleted successfully", del_key)
else:
_LOGGER.debug("Key to delete: %s, Not Found !!!\n Exiting.....", del_key)
raise ModelKeyNotFound(
"Key to delete: %s, Not Found !!!\n Exiting....." % del_key,
"Key to delete: %s, Not Found!" % del_key,
)
except RedisError as err:
raise RedisRegistryError(f"{err.__class__.__name__} raised") from err
else:
_LOGGER.info("Model with the key = %s, deleted successfully", del_key)
self._clear_cache(del_key)

@staticmethod
Expand All @@ -276,6 +323,13 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool:
artifact_data: ArtifactData object to look into
freq_hr: Frequency of retraining in hours.
Returns
-------
True if artifact is stale, False otherwise.
Raises
------
RedisRegistryError: If there is any error while fetching timestamp information.
"""
try:
artifact_ts = float(artifact_data.extras["timestamp"])
Expand Down
6 changes: 4 additions & 2 deletions tests/registry/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def test_cache_size(self):

self.assertIsNone(cache_registry.load("m1"))
self.assertIsInstance(cache_registry.load("m2"), ArtifactData)
self.assertIsInstance(cache_registry.load("m3"), ArtifactData)
self.assertEqual(2, cache_registry.cachesize)
self.assertEqual(1, cache_registry.ttl)
self.assertTrue("m2" in cache_registry)
self.assertTrue("m3" in cache_registry)
self.assertListEqual(["m2", "m3"], cache_registry.keys())

def test_cache_overwrite(self):
cache_registry = LocalLRUCache(cachesize=2, ttl=1)
Expand All @@ -41,7 +43,7 @@ def test_cache_overwrite(self):
)

loaded_artifact = cache_registry.load("m1")
self.assertDictEqual({"version": "2"}, loaded_artifact.extras)
self.assertDictEqual({"version": "2", "source": "cache"}, loaded_artifact.extras)

def test_cache_ttl(self):
cache_registry = LocalLRUCache(cachesize=2, ttl=1)
Expand Down
51 changes: 40 additions & 11 deletions tests/registry/test_redis_registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
import time
import unittest
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
Expand All @@ -12,6 +14,8 @@
from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError

logging.basicConfig(level=logging.DEBUG)


class TestRedisRegistry(unittest.TestCase):
@classmethod
Expand All @@ -25,7 +29,7 @@ def setUpClass(cls) -> None:
cls.redis_client = fakeredis.FakeStrictRedis(server=server, decode_responses=False)

def setUp(self):
self.cache = LocalLRUCache(cachesize=4, ttl=300)
self.cache = LocalLRUCache(cachesize=4, ttl=1)
self.registry = RedisRegistry(
client=self.redis_client,
cache_registry=self.cache,
Expand Down Expand Up @@ -162,19 +166,44 @@ def test_load_model_when_no_model(self):
with self.assertRaises(ModelKeyNotFound):
self.registry.load(skeys=self.skeys, dkeys=self.dkeys)

def test_load_model_when_model_stale(self):
with self.assertRaises(ModelKeyNotFound):
version = self.registry.save(
skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
)
self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=str(version))
self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
def test_load_latest_model_twice(self):
    # Save the model with the clock frozen five days in the past, then load the
    # latest version twice: the first load is expected to report the registry as
    # its source and the second one the cache.
    five_days_ago = datetime.today() - timedelta(days=5)
    with freeze_time(five_days_ago):
        self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model)

    first_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    second_load = self.registry.load(skeys=self.skeys, dkeys=self.dkeys)
    self.assertTrue(self.registry.is_artifact_stale(first_load, 4))
    self.assertEqual("registry", first_load.extras["source"])
    self.assertEqual("cache", second_load.extras["source"])

def test_load_latest_cache_ttl_expire(self):
    # Load the latest model, wait one second, and load it again: both loads are
    # expected to report the registry (not the cache) as their source.
    model_keys = {"skeys": self.skeys, "dkeys": self.dkeys}
    self.registry.save(artifact=self.pytorch_model, **model_keys)
    before_sleep = self.registry.load(**model_keys)
    time.sleep(1)
    after_sleep = self.registry.load(**model_keys)
    self.assertEqual("registry", before_sleep.extras["source"])
    self.assertEqual("registry", after_sleep.extras["source"])

def test_load_non_latest_model_twice(self):
    # Save two versions, then load the older one by explicit version twice:
    # both loads are expected to report the registry as their source.
    model_keys = {"skeys": self.skeys, "dkeys": self.dkeys}
    old_version = self.registry.save(artifact=self.pytorch_model, **model_keys)
    self.registry.save(artifact=self.pytorch_model, **model_keys)

    first_load = self.registry.load(latest=False, version=old_version, **model_keys)
    second_load = self.registry.load(latest=False, version=old_version, **model_keys)
    self.assertEqual("registry", first_load.extras["source"])
    self.assertEqual("registry", second_load.extras["source"])

def test_delete_version(self):
version = self.registry.save(
skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
)
with self.assertRaises(ModelKeyNotFound):
version = self.registry.save(
skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model
)
self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=str(version))
self.registry.load(skeys=self.skeys, dkeys=self.dkeys)

Expand Down

0 comments on commit b21e246

Please sign in to comment.