Skip to content

Commit

Permalink
fix: comments
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Feb 6, 2024
1 parent 780f2e2 commit b41a7f7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
4 changes: 0 additions & 4 deletions numalogic/tools/aggregators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from collections.abc import Callable

import numpy as np
Expand All @@ -7,9 +6,6 @@
from numalogic.transforms import expmov_avg_aggregator


EXP_MOV_AVG_BETA = float(os.getenv("EXP_MOV_AVG_BETA", "0.6"))


def aggregate_window(
y: npt.NDArray[float], agg_func: Callable = expmov_avg_aggregator, **agg_func_kw
) -> npt.NDArray[float]:
Expand Down
8 changes: 6 additions & 2 deletions numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600"))
LOCAL_CACHE_SIZE = int(os.getenv("LOCAL_CACHE_SIZE", "10000"))
LOAD_LATEST = os.getenv("LOAD_LATEST", "false").lower() == "true"
EXP_MOV_AVG_BETA = float(os.getenv("EXP_MOV_AVG_BETA", "0.6"))

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -196,7 +195,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

@staticmethod
def _per_feature_score(feat_names: list[str], scores: NDArray[float]) -> dict[str, float]:
if len(scores) != len(feat_names):
if (scores_len := len(scores)) != len(feat_names):
_LOGGER.debug(
"Scores length: %s does not match feat_names: %s",
scores_len,
feat_names,
)
return {}
return dict(zip(feat_names, scores))

Expand Down
1 change: 0 additions & 1 deletion tests/udfs/test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def test_postprocess_all_model_present_02(self):

msg = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(data), **DATUM_KW))
payload = OutputPayload(**orjson.loads(msg[0].value))
print(payload)
self.assertFalse(list(payload.data))
self.assertEqual(1, len(msg))

Expand Down

0 comments on commit b41a7f7

Please sign in to comment.