Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Feb 27, 2024
1 parent 6623e2b commit ac0eda1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
9 changes: 4 additions & 5 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,10 @@ def is_model_stale(self, artifact_data: ArtifactData, payload: StreamPayload) ->
_conf = self.get_ml_pipeline_conf(
config_id=payload.config_id, pipeline_id=payload.pipeline_id
)
if (
self.model_registry.is_artifact_stale(
artifact_data, _conf.numalogic_conf.trainer.retrain_freq_hr
)
and artifact_data.extras.get("source", "registry") == "registry"
if artifact_data.extras.get(
"source", "registry"
) == "registry" and self.model_registry.is_artifact_stale(
artifact_data, _conf.numalogic_conf.trainer.retrain_freq_hr
):
_LOGGER.info(
"%s - Inference artifact found is stale, Keys: %s, Metric: %s",
Expand Down
21 changes: 20 additions & 1 deletion tests/udfs/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def udf_args():


@freeze_time(datetime.now() + timedelta(hours=7))
def test_inference(udf, udf_args, mocker):
def test_inference_01(udf, udf_args, mocker):
mocker.patch.object(
RedisRegistry,
"load",
Expand All @@ -129,6 +129,25 @@ def test_inference(udf, udf_args, mocker):
assert (12, 2) == payload.get_data().shape


@freeze_time(datetime.now() + timedelta(hours=7))
def test_inference_02(udf, udf_args, mocker):
mocker.patch.object(
RedisRegistry,
"load",
return_value=ArtifactData(
artifact=VanillaAE(seq_len=12, n_features=2),
extras=dict(version="0", timestamp=time.time(), source="cache"),
metadata={},
),
)
msgs = udf(*udf_args)
assert len(msgs) == 1
payload = StreamPayload(**orjson.loads(msgs[0].value))
assert Header.MODEL_INFERENCE == payload.header
assert payload.status == Status.ARTIFACT_FOUND
assert (12, 2) == payload.get_data().shape


def test_inference_stale(udf, udf_args, mocker):
mocker.patch.object(
RedisRegistry,
Expand Down

0 comments on commit ac0eda1

Please sign in to comment.