Skip to content

Commit

Permalink
feat: redis registry (#170)
Browse files Browse the repository at this point in the history
Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm committed May 4, 2023
1 parent 794ddc6 commit f1909a8
Show file tree
Hide file tree
Showing 16 changed files with 632 additions and 70 deletions.
4 changes: 2 additions & 2 deletions numalogic/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.


from numalogic.config._config import NumalogicConf, ModelInfo, LightningTrainerConf, RegistryConf
from numalogic.config._config import NumalogicConf, ModelInfo, LightningTrainerConf, RegistryInfo
from numalogic.config.factory import (
ModelFactory,
PreprocessFactory,
Expand All @@ -23,7 +23,7 @@
"NumalogicConf",
"ModelInfo",
"LightningTrainerConf",
"RegistryConf",
"RegistryInfo",
"ModelFactory",
"PreprocessFactory",
"PostprocessFactory",
Expand Down
13 changes: 9 additions & 4 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@ class ModelInfo:


@dataclass
class RegistryConf:
# TODO implement this
class RegistryInfo:
"""
Registry config base class
Args:
name: name of the registry
conf: kwargs for instantiating the model class
"""
pass

name: str = MISSING
conf: dict[str, Any] = field(default_factory=dict)


@dataclass
Expand Down Expand Up @@ -71,7 +76,7 @@ class NumalogicConf:

model: ModelInfo = field(default_factory=ModelInfo)
trainer: LightningTrainerConf = field(default_factory=LightningTrainerConf)
registry: RegistryConf = field(default_factory=RegistryConf)
registry: RegistryInfo = field(default_factory=RegistryInfo)
preprocess: list[ModelInfo] = field(default_factory=list)
threshold: ModelInfo = field(default_factory=ModelInfo)
postprocess: ModelInfo = field(default_factory=ModelInfo)
26 changes: 18 additions & 8 deletions numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, RobustScaler

from numalogic.config._config import ModelInfo
from numalogic.config._config import ModelInfo, RegistryInfo
from numalogic.models.autoencoder.variants import (
VanillaAE,
SparseVanillaAE,
Expand All @@ -24,27 +26,28 @@
from numalogic.models.threshold import StdDevThreshold, StaticThreshold, SigmoidThreshold
from numalogic.postprocess import TanhNorm, ExpMovingAverage
from numalogic.preprocess import LogTransformer, StaticPowerTransformer, TanhScaler
from numalogic.registry import MLflowRegistry, RedisRegistry
from numalogic.tools.exceptions import UnknownConfigArgsError


class _ObjectFactory:
_CLS_MAP = {}

def get_instance(self, model_info: ModelInfo):
def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]):
try:
_cls = self._CLS_MAP[model_info.name]
_cls = self._CLS_MAP[object_info.name]
except KeyError as err:
raise UnknownConfigArgsError(
f"Invalid model info instance provided: {model_info}"
f"Invalid model info instance provided: {object_info}"
) from err
return _cls(**model_info.conf)
return _cls(**object_info.conf)

def get_cls(self, model_info: ModelInfo):
def get_cls(self, object_info: Union[ModelInfo, RegistryInfo]):
try:
return self._CLS_MAP[model_info.name]
return self._CLS_MAP[object_info.name]
except KeyError as err:
raise UnknownConfigArgsError(
f"Invalid model info instance provided: {model_info}"
f"Invalid model info instance provided: {object_info}"
) from err


Expand Down Expand Up @@ -83,3 +86,10 @@ class ModelFactory(_ObjectFactory):
"TransformerAE": TransformerAE,
"SparseTransformerAE": SparseTransformerAE,
}


class RegistryFactory(_ObjectFactory):
_CLS_MAP = {
"RedisRegistry": RedisRegistry,
"MLflowRegistry": MLflowRegistry,
}
2 changes: 2 additions & 0 deletions numalogic/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

try:
from numalogic.registry.mlflow_registry import MLflowRegistry
from numalogic.registry.redis_registry import RedisRegistry
except ImportError:
__all__ = ["ArtifactManager", "ArtifactData", "ArtifactCache", "LocalLRUCache"]
else:
Expand All @@ -24,4 +25,5 @@
"MLflowRegistry",
"ArtifactCache",
"LocalLRUCache",
"RedisRegistry",
]
14 changes: 14 additions & 0 deletions numalogic/registry/_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import io
import torch


# TODO: ADD other techniques and support for other serialization techniques
def dumps(deserialized_object):
buf = io.BytesIO()
torch.save(deserialized_object, buf)
return buf.getvalue()


def loads(serialized_object):
buffer = io.BytesIO(serialized_object)
return torch.load(buffer)
20 changes: 7 additions & 13 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@


from dataclasses import dataclass
from typing import Any, Generic, TypeVar, Union
from collections.abc import Sequence
from typing import Any, Generic, TypeVar

from numalogic.tools.types import artifact_t, S_KEYS, D_KEYS

META_T = TypeVar("META_T", bound=dict[str, Union[str, list, dict]])
EXTRA_T = TypeVar("EXTRA_T", bound=dict[str, Union[str, list, dict]])
from numalogic.tools.types import artifact_t, KEYS, META_T, EXTRA_T


@dataclass
Expand All @@ -33,7 +29,7 @@ class ArtifactData:
M_K = TypeVar("M_K", bound=str)


class ArtifactManager(Generic[S_KEYS, D_KEYS, A_D]):
class ArtifactManager(Generic[KEYS, A_D]):
"""
Abstract base class for artifact save, load and delete.
Expand All @@ -46,7 +42,7 @@ def __init__(self, uri: str):
self.uri = uri

def load(
self, skeys: Sequence[str], dkeys: Sequence[str], latest: bool = True, version: str = None
self, skeys: KEYS, dkeys: KEYS, latest: bool = True, version: str = None
) -> ArtifactData:
"""
Loads the desired artifact from mlflow registry and returns it.
Expand All @@ -58,9 +54,7 @@ def load(
"""
raise NotImplementedError("Please implement this method!")

def save(
self, skeys: Sequence[str], dkeys: Sequence[str], artifact: artifact_t, **metadata
) -> Any:
def save(self, skeys: KEYS, dkeys: KEYS, artifact: artifact_t, **metadata: META_T) -> Any:
r"""
Saves the artifact into mlflow registry and updates version.
Args:
Expand All @@ -71,7 +65,7 @@ def save(
"""
raise NotImplementedError("Please implement this method!")

def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> None:
def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
"""
Deletes the artifact with a specified version from mlflow registry.
Args:
Expand All @@ -82,7 +76,7 @@ def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> No
raise NotImplementedError("Please implement this method!")

@staticmethod
def construct_key(skeys: Sequence[str], dkeys: Sequence[str]) -> str:
def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
"""
Returns a single key comprising static and dynamic key fields.
Override this method if customization is needed.
Expand Down
21 changes: 9 additions & 12 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import logging
from enum import Enum
from typing import Optional, Any
from collections.abc import Sequence

import mlflow.pyfunc
import mlflow.pytorch
Expand All @@ -25,7 +24,7 @@
from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.registry.artifact import ArtifactCache
from numalogic.tools.exceptions import ModelVersionError
from numalogic.tools.types import artifact_t
from numalogic.tools.types import artifact_t, KEYS, META_T

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -136,7 +135,7 @@ def _clear_cache(self, key: str) -> Optional[ArtifactData]:
return None

def load(
self, skeys: Sequence[str], dkeys: Sequence[str], latest: bool = True, version: str = None
self, skeys: KEYS, dkeys: KEYS, latest: bool = True, version: str = None
) -> Optional[ArtifactData]:
model_key = self.construct_key(skeys, dkeys)

Expand Down Expand Up @@ -180,11 +179,11 @@ def load(

def save(
self,
skeys: Sequence[str],
dkeys: Sequence[str],
skeys: KEYS,
dkeys: KEYS,
artifact: artifact_t,
run_id: str = None,
**metadata: str,
**metadata: META_T,
) -> Optional[ModelVersion]:
"""
Saves the artifact into mlflow registry and updates version.
Expand Down Expand Up @@ -214,7 +213,7 @@ def save(
finally:
mlflow.end_run()

def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> None:
def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None:
"""
Deletes the artifact with a specified version from mlflow registry.
Args:
Expand All @@ -234,9 +233,7 @@ def delete(self, skeys: Sequence[str], dkeys: Sequence[str], version: str) -> No
else:
self._clear_cache(model_key)

def transition_stage(
self, skeys: Sequence[str], dkeys: Sequence[str]
) -> Optional[ModelVersion]:
def transition_stage(self, skeys: KEYS, dkeys: KEYS) -> Optional[ModelVersion]:
"""
Changes stage information for the given model. Sets new model to "Production". The old
production model is set to "Staging" and the rest model versions are set to "Archived".
Expand Down Expand Up @@ -276,7 +273,7 @@ def transition_stage(
_LOGGER.info("Successfully transitioned model to Production stage")
return latest_model_data

def __delete_stale_models(self, skeys: Sequence[str], dkeys: Sequence[str]):
def __delete_stale_models(self, skeys: KEYS, dkeys: KEYS):
model_name = self.construct_key(skeys, dkeys)
list_model_versions = list(self.client.search_model_versions(f"name='{model_name}'"))
if len(list_model_versions) > self.models_to_retain:
Expand All @@ -286,7 +283,7 @@ def __delete_stale_models(self, skeys: Sequence[str], dkeys: Sequence[str]):
_LOGGER.debug("Deleted stale model version : %s", stale_model.version)

def __load_artifacts(
self, skeys: Sequence[str], dkeys: Sequence[str], version_info: ModelVersion
self, skeys: KEYS, dkeys: KEYS, version_info: ModelVersion
) -> tuple[artifact_t, dict[str, Any]]:
model_key = self.construct_key(skeys, dkeys)
model = self.handler.load_model(model_uri=f"models:/{model_key}/{version_info.version}")
Expand Down
Loading

0 comments on commit f1909a8

Please sign in to comment.