Skip to content

Commit

Permalink
Postproc support for None (#375)
Browse files Browse the repository at this point in the history
Support for None for postproc.

---------

Signed-off-by: Kushal Batra <[email protected]>
  • Loading branch information
s0nicboOm committed May 3, 2024
1 parent e1ae3ee commit 141e9a0
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class NumalogicConf:
trainer: TrainerConf = field(default_factory=TrainerConf)
preprocess: list[ModelInfo] = field(default_factory=list)
threshold: ModelInfo = field(default_factory=lambda: ModelInfo(name="StdDevThreshold"))
postprocess: ModelInfo = field(
postprocess: Optional[ModelInfo] = field(
default_factory=lambda: ModelInfo(name="TanhNorm", stateful=False)
)
score: ScoreConf = field(default_factory=lambda: ScoreConf())
Expand Down
6 changes: 5 additions & 1 deletion numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
load_latest=LOAD_LATEST,
vertex=self._vtx,
)
postproc_tx = self.postproc_factory.get_instance(postprocess_cfg)
postproc_tx = (
self.postproc_factory.get_instance(postprocess_cfg) if postprocess_cfg else None
)
if not postproc_tx:
logger.info("Postprocess model is absent!")

if thresh_artifact is None:
payload = replace(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.9.1a8"
version = "0.9.1a9"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
5 changes: 2 additions & 3 deletions tests/connectors/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def group_by(*_, **__):
"env": "prod",
"status": 200,
"http_status": "2xx",
"count": 20
"count": 20,
},
"timestamp": "2023-09-06T07:50:00.000Z",
"version": "v1",
Expand All @@ -120,8 +120,7 @@ def group_by(*_, **__):
"env": "prod",
"status": 500,
"http_status": "5xx",
"count": 10

"count": 10,
},
"timestamp": "2023-09-06T07:53:00.000Z",
"version": "v1",
Expand Down
11 changes: 10 additions & 1 deletion tests/udfs/test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numalogic._constants import TESTS_DIR
from numalogic.models.threshold import StdDevThreshold
from numalogic.registry import RedisRegistry, ArtifactData
from numalogic.transforms import TanhNorm
from numalogic.udfs import PipelineConf
from numalogic.udfs.entities import Header, TrainerPayload, Status, OutputPayload
from numalogic.udfs.postprocess import PostprocessUDF
Expand Down Expand Up @@ -172,7 +173,15 @@ def test_postprocess_runtime_err_02(udf, mocker, bad_artifact):
assert msgs[1].tags == ["staticthresh"]


def test_compute(udf, artifact):
def test_compute_without_postproc(udf, artifact):
y_unified, x_inferred = udf.compute(artifact.artifact, np.asarray(DATA["data"]))
assert isinstance(y_unified, float)
assert x_inferred.shape == (2,)


def test_compute_with_postproc(udf, artifact):
y_unified, x_inferred = udf.compute(
artifact.artifact, np.asarray(DATA["data"]), postproc_tx=TanhNorm()
)
assert isinstance(y_unified, float)
assert x_inferred.shape == (2,)

0 comments on commit 141e9a0

Please sign in to comment.