Skip to content

Commit

Permalink
fix: transition (#144)
Browse files Browse the repository at this point in the history
* fix: transition

Signed-off-by: s0nicboOm <[email protected]>

* fix: flake test error

Signed-off-by: s0nicboOm <[email protected]>

* fix: make private function to delete stale models

Signed-off-by: s0nicboOm <[email protected]>

* fix: patch update

Signed-off-by: s0nicboOm <[email protected]>

* fix: delete stale models

Signed-off-by: s0nicboOm <[email protected]>

* add: logs and exception catch

Signed-off-by: s0nicboOm <[email protected]>

* fix: testcases

Signed-off-by: s0nicboOm <[email protected]>

---------

Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm committed Mar 14, 2023
1 parent 26623a5 commit 6ee3446
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
48 changes: 24 additions & 24 deletions numalogic/registry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,48 +226,48 @@ def transition_stage(
"""
model_name = self.construct_key(skeys, dkeys)
try:
version = int(self.get_version(model_name=model_name))
current_production = self.client.get_latest_versions(
name=model_name, stages=["Production"]
)
current_staging = self.client.get_latest_versions(name=model_name, stages=["Staging"])
latest = self.client.get_latest_versions(name=model_name, stages=["None"])

latest_model_data = self.client.transition_model_version_stage(
name=model_name,
version=str(version),
version=str(latest[-1].version),
stage=ModelStage.PRODUCTION,
)
if version - 1 > 0:

if current_production:
self.client.transition_model_version_stage(
name=model_name,
version=str(version - 1),
version=str(current_production[-1].version),
stage=ModelStage.STAGE,
)
if version - 2 > 0:

if current_staging:
self.client.transition_model_version_stage(
name=model_name,
version=str(version - 2),
version=str(current_staging[-1].version),
stage=ModelStage.ARCHIVE,
)

# only keep "models_to_retain" number of models.
list_model_versions = list(self.client.search_model_versions(f"name='{model_name}'"))
models_to_delete = list_model_versions[: -self.models_to_retain]
for stale_model in models_to_delete:
self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version)
self.__delete_stale_models(skeys=skeys, dkeys=dkeys)

_LOGGER.info("Successfully transitioned model to Production stage")
return latest_model_data
except Exception as ex:
except RestException as ex:
_LOGGER.exception(
"Error when transitioning a model: %s to different stage: %r", model_name, ex
)
return None

def get_version(self, model_name: str) -> Optional[ModelVersion]:
"""
Get model's latest version given the model name
Args:
model_name: model name for which the version has to be identified.
Returns:
version from mlflow ModelVersion instance
"""
try:
return self.client.get_latest_versions(name=model_name, stages=[])[-1].version
except RestException as ex:
_LOGGER.error("Error when getting model version: %r", ex)
return None
def __delete_stale_models(self, skeys: Sequence[str], dkeys: Sequence[str]):
model_name = self.construct_key(skeys, dkeys)
list_model_versions = list(self.client.search_model_versions(f"name='{model_name}'"))
if len(list_model_versions) > self.models_to_retain:
models_to_delete = list_model_versions[self.models_to_retain :]
for stale_model in models_to_delete:
self.delete(skeys=skeys, dkeys=dkeys, version=stale_model.version)
_LOGGER.debug("Deleted stale model version : %s", stale_model.version)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.3.5"
version = "0.3.6"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
2 changes: 1 addition & 1 deletion tests/registry/test_mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_load_model_when_no_model_02(self):
@patch("mlflow.tracking.MlflowClient.get_latest_versions", mock_get_model_version)
@patch(
"mlflow.tracking.MlflowClient.transition_model_version_stage",
Mock(side_effect=RuntimeError),
Mock(side_effect=RestException({"error_code": ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)})),
)
def test_transition_stage_fail(self):
fake_skeys = ["Fakemodel_"]
Expand Down

0 comments on commit 6ee3446

Please sign in to comment.