feat!: support full multivariate prometheus fetching (#325)
- support multivariate fetching in the Prometheus fetcher
- fix conf keys used for training and inference
- fix inf filling (inf values could previously cause NaN outputs or break training)

---------

Signed-off-by: Avik Basu <[email protected]>
ab93 committed Nov 27, 2023
1 parent c967f20 commit 8b7f45f
Showing 18 changed files with 960 additions and 135 deletions.
3 changes: 2 additions & 1 deletion numalogic/connectors/_config.py
@@ -16,8 +16,9 @@ class ConnectorConf:
 
 @dataclass
 class PrometheusConf(ConnectorConf):
-    pushgateway: str
+    pushgateway: str = ""
     scrape_interval: int = 30
+    return_labels: list[str] = field(default_factory=list)
 
 
 @dataclass
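With this change, pushgateway gains an empty-string default so read-only Prometheus configs no longer need to set it, and the new return_labels field names the Prometheus labels the fetcher should keep as dataframe columns. A minimal construction sketch, assuming the dataclass as shown above (the url field is assumed to be inherited from ConnectorConf):

    from numalogic.connectors._config import PrometheusConf

    # Read-only config: pushgateway can now be omitted (defaults to "").
    prom_conf = PrometheusConf(
        url="http://localhost:9090",  # assumed to come from ConnectorConf
        scrape_interval=30,
        return_labels=["rollouts_pod_template_hash"],  # labels kept as columns
    )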
34 changes: 22 additions & 12 deletions numalogic/connectors/prometheus.py
@@ -58,9 +58,9 @@ def build_query(metric: str, filters: dict[str, str]) -> str:
 
     def fetch(
         self,
-        metric_name: str,
         start: datetime,
         end: Optional[datetime] = None,
+        metric_name: str = "",
         filters: Optional[dict[str, str]] = None,
         return_labels: Optional[list[str]] = None,
         aggregate: bool = True,
@@ -70,9 +70,9 @@ def fetch(
 
         Args:
         -------
-            metric_name: Prometheus metric name
             start: Start time
             end: End time
+            metric_name: Prometheus metric name (default="")
             filters: Prometheus label filters
             return_labels: Prometheus label names as columns to return
             aggregate: Whether to aggregate the data
@@ -96,22 +96,32 @@ def fetch(
         results = self.query_range(query, start_ts, end_ts)
 
         df = pd.json_normalize(results)
-        return_labels = [f"metric.{label}" for label in return_labels or []]
         if df.empty:
             LOGGER.warning("Query returned no results")
             return df
 
-        df = self._consolidate_df(df, metric_name, return_labels)
-        if aggregate and return_labels:
-            df = self._agg_df(df, [metric_name])
+        extra_labels = [f"metric.{label}" for label in return_labels or []]
+        if metric_name:
+            metric_names = [metric_name]
+        else:
+            metric_names = self._extract_metric_names(df)
 
-        try:
-            df.set_index("timestamp", inplace=True)
-        except KeyError:
-            pass
-        df.sort_values(by="timestamp", inplace=True)
+        df.set_index(_METRIC_KEY, inplace=True)
 
-        return df
+        dfs = []
+        for metric_name in metric_names:
+            _df = self._consolidate_df(df.loc[[metric_name]], metric_name, extra_labels)
+            dfs.append(_df.set_index(["timestamp", *extra_labels]))
+
+        df = dfs[0].join(dfs[1:]).reset_index().set_index("timestamp")
+
+        if return_labels:
+            df.rename(columns=dict(zip(extra_labels, return_labels)), inplace=True)
+
+        if aggregate:
+            df = self._agg_df(df, metric_names)
+
+        return df.sort_values(by=["timestamp"])
 
     def raw_fetch(
         self,
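Taken together, metric_name is now optional: when it is omitted, every metric name discovered in the query result becomes its own column, joined on timestamp (and on any return_labels). A hedged usage sketch; the fetcher constructor arguments, namespace, and label names here are illustrative, not taken from this diff:

    from datetime import datetime, timedelta

    from numalogic.connectors.prometheus import PrometheusFetcher

    fetcher = PrometheusFetcher(prometheus_server="http://localhost:9090")

    # Multivariate fetch: no metric_name, so all metrics matching the label
    # filters come back as separate columns indexed by timestamp.
    df = fetcher.fetch(
        start=datetime.utcnow() - timedelta(hours=3),
        filters={"namespace": "my-namespace"},
        return_labels=["rollouts_pod_template_hash"],
        aggregate=True,
    )
    print(df.head())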
4 changes: 3 additions & 1 deletion numalogic/udfs/__init__.py
@@ -9,7 +9,7 @@
 from numalogic.udfs.inference import InferenceUDF
 from numalogic.udfs.postprocess import PostprocessUDF
 from numalogic.udfs.preprocess import PreprocessUDF
-from numalogic.udfs.trainer import TrainerUDF
+from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF
 
 
 def set_logger() -> None:
@@ -27,6 +27,8 @@ def set_logger() -> None:
     "PreprocessUDF",
     "InferenceUDF",
     "TrainerUDF",
+    "PromTrainerUDF",
+    "DruidTrainerUDF",
     "PostprocessUDF",
     "UDFFactory",
     "StreamConf",
25 changes: 17 additions & 8 deletions numalogic/udfs/factory.py
@@ -10,27 +10,28 @@
 # limitations under the License.
 
 import logging
+import warnings
 from typing import ClassVar
 
-from pynumaflow.mapper import Mapper, MultiProcMapper, AsyncMapper
-
-from numalogic.udfs import NumalogicUDF
-from numalogic.udfs.inference import InferenceUDF
-from numalogic.udfs.postprocess import PostprocessUDF
-from numalogic.udfs.preprocess import PreprocessUDF
-from numalogic.udfs.trainer import TrainerUDF
-
 _LOGGER = logging.getLogger(__name__)
 
 
 class UDFFactory:
     """Factory class to fetch the right UDF."""
 
+    from numalogic.udfs import NumalogicUDF
+    from numalogic.udfs.inference import InferenceUDF
+    from numalogic.udfs.postprocess import PostprocessUDF
+    from numalogic.udfs.preprocess import PreprocessUDF
+    from numalogic.udfs.trainer import DruidTrainerUDF, PromTrainerUDF
+
     _UDF_MAP: ClassVar[dict[str, type[NumalogicUDF]]] = {
         "preprocess": PreprocessUDF,
         "inference": InferenceUDF,
         "postprocess": PostprocessUDF,
-        "trainer": TrainerUDF,
+        "druidtrainer": DruidTrainerUDF,
+        "promtrainer": PromTrainerUDF,
     }
 
     @classmethod
@@ -50,6 +51,12 @@ def get_udf_cls(cls, udf_name: str) -> type[NumalogicUDF]:
         ------
             ValueError: If the UDF name is invalid
         """
+        if udf_name == "trainer":
+            warnings.warn(
+                "UDF name 'trainer' is deprecated. Use 'druidtrainer' or 'promtrainer' instead."
+            )
+            udf_name = "druidtrainer"
+
         try:
             return cls._UDF_MAP[udf_name]
         except KeyError as err:
@@ -81,6 +88,8 @@ def get_udf_instance(cls, udf_name: str, **kwargs) -> NumalogicUDF:
 class ServerFactory:
     """Factory class to fetch the right pynumaflow function server/mapper."""
 
+    from pynumaflow.mapper import Mapper, MultiProcMapper, AsyncMapper
+
     _SERVER_MAP: ClassVar[dict] = {
         "sync": Mapper,
         "async": AsyncMapper,
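Callers now pick a trainer by connector type; the old "trainer" name still resolves, but it warns and falls back to the Druid trainer. A small usage sketch of the lookup behavior shown above:

    import warnings

    from numalogic.udfs import UDFFactory

    # New-style lookups: one trainer class per data connector.
    prom_trainer_cls = UDFFactory.get_udf_cls("promtrainer")
    druid_trainer_cls = UDFFactory.get_udf_cls("druidtrainer")

    # Legacy name: warns and resolves to the Druid trainer.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_cls = UDFFactory.get_udf_cls("trainer")

    assert legacy_cls is druid_trainer_cls
    assert "deprecated" in str(caught[0].message)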
7 changes: 5 additions & 2 deletions numalogic/udfs/inference.py
@@ -105,9 +105,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         # Forward payload if a training request is tagged
         if payload.header == Header.TRAIN_REQUEST:
             return Messages(Message(keys=keys, value=payload.to_json()))
+
+        _conf = self.get_conf(payload.config_id)
+
         artifact_data, payload = _load_artifact(
-            skeys=keys,
-            dkeys=[self.get_conf(payload.config_id).numalogic_conf.model.name],
+            skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
+            dkeys=[_conf.numalogic_conf.model.name],
             payload=payload,
             model_registry=self.model_registry,
             load_latest=LOAD_LATEST,
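The registry store keys are now derived by zipping the configured composite key names against the payload's composite key values instead of trusting the raw vertex keys; since zip() stops at the shorter sequence, any extra vertex keys are dropped. A minimal illustration with made-up values:

    # Illustrative only: composite_keys comes from the stream config,
    # payload.composite_keys from the incoming payload.
    composite_keys = ["namespace", "app"]
    payload_keys = ["prod-ns", "web-app", "extra-vertex-key"]

    skeys = [_ckey for _, _ckey in zip(composite_keys, payload_keys)]
    assert skeys == ["prod-ns", "web-app"]  # trailing extras are ignored

The same key-alignment fix is applied in the postprocess and preprocess UDFs below.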
11 changes: 7 additions & 4 deletions numalogic/udfs/postprocess.py
@@ -76,12 +76,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         payload = StreamPayload(**orjson.loads(datum.value))
 
         # load configs
-        thresh_cfg = self.get_conf(payload.config_id).numalogic_conf.threshold
-        postprocess_cfg = self.get_conf(payload.config_id).numalogic_conf.postprocess
+        _conf = self.get_conf(payload.config_id)
+        thresh_cfg = _conf.numalogic_conf.threshold
+        postprocess_cfg = _conf.numalogic_conf.postprocess
 
         # load artifact
         thresh_artifact, payload = _load_artifact(
-            skeys=keys,
+            skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
             dkeys=[thresh_cfg.name],
             payload=payload,
             model_registry=self.model_registry,
@@ -136,9 +137,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
 
         # Forward payload if a training request is tagged
         if payload.header == Header.TRAIN_REQUEST or payload.status == Status.ARTIFACT_STALE:
+            _conf = self.get_conf(payload.config_id)
+            ckeys = [_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)]
             train_payload = TrainerPayload(
                 uuid=payload.uuid,
-                composite_keys=keys,
+                composite_keys=ckeys,
                 metrics=payload.metrics,
                 config_id=payload.config_id,
             )
21 changes: 9 additions & 12 deletions numalogic/udfs/preprocess.py
@@ -5,7 +5,7 @@
 from typing import Optional
 
 import orjson
-from numpy._typing import NDArray
+from numpy.typing import NDArray
 from pynumaflow.mapper import Datum, Messages, Message
 from sklearn.pipeline import make_pipeline
 
@@ -97,15 +97,12 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
         payload = make_stream_payload(data_payload, raw_df, timestamps, keys)
 
         # Check if model will be present in registry
-        if any(
-            [_conf.stateful for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess]
-        ):
+
+        _conf = self.get_conf(payload.config_id)
+        if any(_cfg.stateful for _cfg in _conf.numalogic_conf.preprocess):
             preproc_artifact, payload = _load_artifact(
-                skeys=keys,
-                dkeys=[
-                    _conf.name
-                    for _conf in self.get_conf(payload.config_id).numalogic_conf.preprocess
-                ],
+                skeys=[_ckey for _, _ckey in zip(_conf.composite_keys, payload.composite_keys)],
+                dkeys=[_cfg.name for _cfg in _conf.numalogic_conf.preprocess],
                 payload=payload,
                 model_registry=self.model_registry,
                 load_latest=LOAD_LATEST,
@@ -123,13 +120,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
                     payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST
                 )
                 return Messages(Message(keys=keys, value=payload.to_json()))
+
         # Model will not be in registry
         else:
             # Load configuration for the config_id
             _LOGGER.info("%s - Initializing model from config: %s", payload.uuid, payload)
-            preproc_clf = self._load_model_from_config(
-                self.get_conf(payload.config_id).numalogic_conf.preprocess
-            )
+            preproc_clf = self._load_model_from_config(_conf.numalogic_conf.preprocess)
+
         try:
             x_scaled = self.compute(model=preproc_clf, input_=payload.get_data())
             payload = replace(
5 changes: 5 additions & 0 deletions numalogic/udfs/trainer/__init__.py
@@ -0,0 +1,5 @@
+from numalogic.udfs.trainer._base import TrainerUDF
+from numalogic.udfs.trainer._prom import PromTrainerUDF
+from numalogic.udfs.trainer._druid import DruidTrainerUDF
+
+__all__ = ["TrainerUDF", "PromTrainerUDF", "DruidTrainerUDF"]
84 changes: 5 additions & 79 deletions numalogic/udfs/trainer.py → numalogic/udfs/trainer/_base.py
@@ -14,8 +14,6 @@
 from numalogic.base import StatelessTransformer
 from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory, RegistryFactory
 from numalogic.config._config import NumalogicConf
-from numalogic.config.factory import ConnectorFactory
-from numalogic.connectors import DruidFetcherConf
 from numalogic.models.autoencoder import TimeseriesTrainer
 from numalogic.tools.data import StreamingDataset
 from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError
@@ -55,52 +53,12 @@ def __init__(
             jitter_sec=jitter_sec,
             jitter_steps_sec=jitter_steps_sec,
         )
-        self.druid_conf = self.pl_conf.druid_conf
-
-        data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher")
-        try:
-            self.data_fetcher = data_fetcher_cls(
-                url=self.druid_conf.url, endpoint=self.druid_conf.endpoint
-            )
-        except AttributeError:
-            _LOGGER.warning("Druid config not found, data fetcher will not be initialized!")
-            self.data_fetcher = None
-
         self._model_factory = ModelFactory()
         self._preproc_factory = PreprocessFactory()
         self._thresh_factory = ThresholdFactory()
         self.train_msg_deduplicator = TrainMsgDeduplicator(r_client)
 
-    def register_druid_fetcher_conf(self, config_id: str, conf: DruidFetcherConf) -> None:
-        """
-        Register DruidFetcherConf with the UDF.
-
-        Args:
-            config_id: Config ID
-            conf: DruidFetcherConf object
-        """
-        self.pl_conf.druid_conf.id_fetcher[config_id] = conf
-
-    def get_druid_fetcher_conf(self, config_id: str) -> DruidFetcherConf:
-        """
-        Get DruidFetcherConf with the given ID.
-
-        Args:
-            config_id: Config ID
-
-        Returns
-        -------
-            DruidFetcherConf object
-
-        Raises
-        ------
-            ConfigNotFoundError: If config with the given ID is not found
-        """
-        try:
-            return self.pl_conf.druid_conf.id_fetcher[config_id]
-        except KeyError as err:
-            raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err
-
     @classmethod
     def compute(
         cls,
@@ -220,7 +178,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
             numalogic_cfg=_conf.numalogic_conf,
         )
 
-        # Save artifacts`
+        # Save artifacts
         skeys = payload.composite_keys
 
         self.artifacts_to_save(
@@ -298,6 +256,7 @@ def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool
             return False
         return True
 
+    # TODO: Use a custom imputer in transforms module
     @staticmethod
     def get_feature_arr(
         raw_df: pd.DataFrame, metrics: list[str], fill_value: float = 0.0
@@ -307,12 +266,12 @@ def get_feature_arr(
             if col not in raw_df.columns:
                 raw_df[col] = fill_value
         feat_df = raw_df[metrics]
-        feat_df = feat_df.fillna(fill_value)
+        feat_df = feat_df.fillna(fill_value).replace([np.inf, -np.inf], fill_value)
         return feat_df.to_numpy(dtype=np.float32)
 
     def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
         """
-        Fetch data from druid.
+        Fetch data from a data connector.
 
         Args:
             payload: TrainerPayload object
@@ -321,37 +280,4 @@ def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
         -------
             Dataframe
         """
-        _start_time = time.perf_counter()
-        _conf = self.get_conf(payload.config_id)
-        _fetcher_conf = self.druid_conf.fetcher or (
-            self.get_druid_fetcher_conf(payload.config_id) if self.druid_conf.id_fetcher else None
-        )
-        if not _fetcher_conf:
-            raise ConfigNotFoundError(
-                f"Druid fetcher config not found for config_id: {payload.config_id}!"
-            )
-
-        try:
-            _df = self.data_fetcher.fetch(
-                datasource=_fetcher_conf.datasource,
-                filter_keys=_conf.composite_keys,
-                filter_values=payload.composite_keys,
-                dimensions=list(_fetcher_conf.dimensions),
-                delay=self.druid_conf.delay_hrs,
-                granularity=_fetcher_conf.granularity,
-                aggregations=dict(_fetcher_conf.aggregations),
-                group_by=list(_fetcher_conf.group_by),
-                pivot=_fetcher_conf.pivot,
-                hours=_conf.numalogic_conf.trainer.train_hours,
-            )
-        except Exception:
-            _LOGGER.exception("%s - Error while fetching data from druid", payload.uuid)
-            return pd.DataFrame()
-
-        _LOGGER.debug(
-            "%s - Time taken to fetch data: %.3f sec, df shape: %s",
-            payload.uuid,
-            time.perf_counter() - _start_time,
-            _df.shape,
-        )
-        return _df
+        raise NotImplementedError("fetch_data method not implemented")
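fetch_data is now an abstract hook: the base class keeps the training, deduplication, and artifact-saving logic, while connector-specific subclasses (PromTrainerUDF, DruidTrainerUDF) supply the data. A hedged sketch of a custom subclass and of the inf-filling fix; the CSV source is illustrative, and the TrainerPayload import path is assumed:

    import numpy as np
    import pandas as pd

    from numalogic.udfs.entities import TrainerPayload  # assumed import path
    from numalogic.udfs.trainer import TrainerUDF


    class CsvTrainerUDF(TrainerUDF):
        """Illustrative trainer that reads training data from a local CSV."""

        def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
            # Any source works, as long as a dataframe with the configured
            # metric columns comes back; on failure, return an empty frame.
            try:
                return pd.read_csv(f"/data/{payload.config_id}.csv")
            except OSError:
                return pd.DataFrame()


    # The inf-filling fix in get_feature_arr: NaN and +/-inf are both
    # replaced with fill_value before the array reaches the trainer.
    raw = pd.DataFrame({"m1": [1.0, np.inf, np.nan], "m2": [-np.inf, 2.0, 3.0]})
    arr = TrainerUDF.get_feature_arr(raw, metrics=["m1", "m2"], fill_value=0.0)
    assert not np.isnan(arr).any() and not np.isinf(arr).any()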
