Skip to content

Commit

Permalink
fix: use ckeys aligning with config in pre, post and inference vtx
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Nov 21, 2023
1 parent b761c94 commit 5aa4c13
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
7 changes: 5 additions & 2 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST:
return Messages(Message(keys=keys, value=payload.to_json()))

_conf = self.get_conf(payload.config_id)

artifact_data, payload = _load_artifact(
skeys=keys,
dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
dkeys=[_conf.numalogic_conf.model.name],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
Expand Down
9 changes: 5 additions & 4 deletions numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload = StreamPayload(**orjson.loads(datum.value))

# load configs
thresh_cfg = self.get_conf(payload.config_id).numalogic_conf.threshold
postprocess_cfg = self.get_conf(payload.config_id).numalogic_conf.postprocess
_conf = self.get_conf(payload.config_id)
thresh_cfg = _conf.numalogic_conf.threshold
postprocess_cfg = _conf.numalogic_conf.postprocess

# load artifact
thresh_artifact, payload = _load_artifact(
skeys=keys,
skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
dkeys=[thresh_cfg.name],
payload=payload,
model_registry=self.model_registry,
Expand Down Expand Up @@ -137,7 +138,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
# Forward payload if a training request is tagged
if payload.header == Header.TRAIN_REQUEST or payload.status == Status.ARTIFACT_STALE:
_conf = self.get_conf(payload.config_id)
ckeys = [_item[1] for _item in zip(_conf.composite_keys, payload.composite_keys)]
ckeys = [_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)]
train_payload = TrainerPayload(
uuid=payload.uuid,
composite_keys=ckeys,
Expand Down
21 changes: 9 additions & 12 deletions numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

import orjson
from numpy._typing import NDArray
from numpy.typing import NDArray
from pynumaflow.mapper import Datum, Messages, Message
from sklearn.pipeline import make_pipeline

Expand Down Expand Up @@ -97,15 +97,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload = make_stream_payload(data_payload, raw_df, timestamps, keys)

# Check if model will be present in registry
if any(
[_conf.stateful for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess]
):

_conf = self.get_conf(payload.config_id)
if any(_cfg.stateful for _cfg in _conf.numalogic_conf.preprocess):
preproc_artifact, payload = _load_artifact(
skeys=keys,
dkeys=[
_conf.name
for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess
],
skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
dkeys=[_cfg.name for _cfg in _conf.numalogic_conf.preprocess],
payload=payload,
model_registry=self.model_registry,
load_latest=LOAD_LATEST,
Expand All @@ -123,13 +120,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST
)
return Messages(Message(keys=keys, value=payload.to_json()))

# Model will not be in registry
else:
# Load configuration for the config_id
_LOGGER.info("%s - Initializing model from config: %s", payload.uuid, payload)
preproc_clf = self._load_model_from_config(
self.get_conf(payload.config_id).numalogic_conf.preprocess
)
preproc_clf = self._load_model_from_config(_conf.numalogic_conf.preprocess)

try:
x_scaled = self.compute(model=preproc_clf, input_=payload.get_data())
payload = replace(
Expand Down
2 changes: 1 addition & 1 deletion numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
numalogic_cfg=_conf.numalogic_conf,
)

# Save artifacts`
# Save artifacts
skeys = payload.composite_keys

self.artifacts_to_save(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.1.dev2"
version = "0.6.1.dev3"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down Expand Up @@ -108,7 +108,7 @@ exclude = '''
line-length = 100
src = ["numalogic", "tests"]
select = ["E", "F", "W", "C901", "NPY", "RUF", "TRY", "G", "PLE", "PLW", "UP", "ICN", "RET", "Q" , "PLR", "D"]
ignore = ["TRY003", "TRY301", "RUF100", "D100", "D104", "PLR2004", "D102", "D401", "D107", "D205", "D105", "PLW0603"]
ignore = ["TRY003", "TRY301", "RUF100", "D100", "D104", "PLR2004", "D102", "D401", "D107", "D205", "D105", "PLW0603", "PLR0915"]
target-version = "py39"
show-fixes = true
show-source = true
Expand Down

0 comments on commit 5aa4c13

Please sign in to comment.