Skip to content

Commit

Permalink
fix: add caching logs (#203)
Browse files Browse the repository at this point in the history
---------
Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm committed Jun 2, 2023
1 parent 4411aa7 commit 6727296
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def load(
if latest:
cached_artifact = self._load_from_cache(model_key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", model_key)
return cached_artifact
version_info = self.client.get_latest_versions(model_key, stages=[self.model_stage])
if not version_info:
Expand Down
1 change: 1 addition & 0 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __get_artifact_data(
def __load_latest_artifact(self, key: str) -> ArtifactData:
cached_artifact = self._load_from_cache(key)
if cached_artifact:
_LOGGER.debug("Found cached artifact for key: %s", key)
return cached_artifact
production_key = self.__construct_production_key(key)
if not self.client.exists(production_key):
Expand Down
27 changes: 26 additions & 1 deletion tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.ensemble import RandomForestRegressor

from numalogic.models.autoencoder.variants import VanillaAE
from numalogic.registry import MLflowRegistry, ArtifactData
from numalogic.registry import MLflowRegistry, ArtifactData, LocalLRUCache
from numalogic.registry.mlflow_registry import ModelStage
from numalogic.tools.exceptions import ModelVersionError
from tests.registry._mlflow_utils import (
Expand Down Expand Up @@ -338,6 +338,31 @@ def test_no_cache(self):
self.assertIsNone(registry._load_from_cache("key"))
self.assertIsNone(registry._clear_cache("key"))

def test_cache(self):
cache_registry = LocalLRUCache()
registry = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
registry._save_in_cache("key", ArtifactData(artifact=self.model, extras={}, metadata={}))
self.assertIsNotNone(registry._load_from_cache("key"))
self.assertIsNotNone(registry._clear_cache("key"))

@patch("mlflow.pytorch.log_model", mock_log_model_pytorch())
@patch("mlflow.start_run", Mock(return_value=ActiveRun(return_pytorch_rundata_dict())))
@patch("mlflow.active_run", Mock(return_value=return_pytorch_rundata_dict()))
@patch("mlflow.log_params", {"lr": 0.01})
@patch("mlflow.tracking.MlflowClient.transition_model_version_stage", mock_transition_stage)
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch("mlflow.pytorch.load_model", Mock(return_value=VanillaAE(10)))
@patch("mlflow.tracking.MlflowClient.get_run", Mock(return_value=return_pytorch_rundata_dict()))
def test_cache_loading(self):
cache_registry = LocalLRUCache(ttl=50000)
ml = MLflowRegistry(TRACKING_URI, cache_registry=cache_registry)
ml.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.model, **{"lr": 0.01})
ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
key = MLflowRegistry.construct_key(self.skeys, self.dkeys)
self.assertIsNotNone(ml._load_from_cache(key))
data = ml.load(skeys=self.skeys, dkeys=self.dkeys, artifact_type="pytorch")
self.assertIsNotNone(data)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6727296

Please sign in to comment.