Skip to content

Commit

Permalink
feat: fallback to stddev if threshold is too low
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Apr 23, 2024
1 parent 5806a68 commit 9ad6d8d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
18 changes: 17 additions & 1 deletion numalogic/models/threshold/_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from numalogic.base import BaseThresholdModel
from numalogic.tools.exceptions import InvalidDataShapeError, ModelInitializationError

import logging

LOGGER = logging.getLogger(__name__)
_INLIER: Final[int] = 0
_OUTLIER: Final[int] = 1
_INPUT_DIMS: Final[int] = 2
Expand All @@ -19,18 +22,20 @@ class MaxPercentileThreshold(BaseThresholdModel):
min_threshold: Value to be used if threshold is less than this
"""

__slots__ = ("_max_percentile", "_min_thresh", "_thresh", "_is_fitted")
__slots__ = ("_max_percentile", "_min_thresh", "_thresh", "_is_fitted", "_adjust_threshold")

def __init__(
    self,
    max_inlier_percentile: float = 96.0,
    min_threshold: float = 1e-4,
    adjust_threshold: bool = False,
):
    """Create an unfitted threshold model.

    Args:
        max_inlier_percentile: percentile handed to ``np.percentile``
            when ``fit`` computes the per-column thresholds.
        min_threshold: floor applied to every fitted threshold.
        adjust_threshold: when true, ``fit`` replaces thresholds that
            are far below ``min_threshold`` with a mean/stddev based
            value before the floor is applied.
    """
    super().__init__()

    # User supplied configuration.
    self._adjust_threshold = adjust_threshold
    self._min_thresh = min_threshold
    self._max_percentile = max_inlier_percentile

    # Fitted state; populated by ``fit``.
    self._thresh, self._is_fitted = None, False

@property
def threshold(self):
Expand All @@ -45,6 +50,17 @@ def _validate_input(x: npt.NDArray[float]) -> None:
def fit(self, x: npt.NDArray[float]) -> Self:
self._validate_input(x)
self._thresh = np.percentile(x, self._max_percentile, axis=0)

if self._adjust_threshold:
for idx, _ in enumerate(self._thresh):
if self._thresh[idx] / self._min_thresh < 1e-2:
LOGGER.info(

Check warning on line 57 in numalogic/models/threshold/_median.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/threshold/_median.py#L57

Added line #L57 was not covered by tests
"Min threshold is less than 1e-2 times the "
"threshold for column %s; Using mean instead.",
idx,
)
self._thresh[idx] = np.mean(x[:, idx]) + (3 * np.std(x[:, idx]))

Check warning on line 62 in numalogic/models/threshold/_median.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/threshold/_median.py#L62

Added line #L62 was not covered by tests

self._thresh[self._thresh < self._min_thresh] = self._min_thresh
self._is_fitted = True
return self
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.9.1a3"
version = "0.9.1a4"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down

0 comments on commit 9ad6d8d

Please sign in to comment.