Skip to content

Commit

Permalink
feat!: separate Prom trainer and Druid trainer
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Nov 20, 2023
1 parent 7be6ddc commit 8369510
Show file tree
Hide file tree
Showing 13 changed files with 871 additions and 102 deletions.
3 changes: 2 additions & 1 deletion numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ class ConnectorConf:

@dataclass
class PrometheusConf(ConnectorConf):
    """
    Prometheus connector configuration.

    Args:
        pushgateway: address of the Prometheus pushgateway
            (empty string by default, i.e. push is not configured)
        scrape_interval: scrape interval in seconds (default: 30)
        return_labels: metric labels to retain in the fetched data
            (empty list by default)
    """

    # NOTE: all fields carry defaults so the conf can be constructed empty
    # and filled in from pipeline config.
    pushgateway: str = ""
    scrape_interval: int = 30
    return_labels: list[str] = field(default_factory=list)


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion numalogic/udfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.postprocess import PostprocessUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import TrainerUDF
from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF


def set_logger() -> None:
Expand All @@ -27,6 +27,8 @@ def set_logger() -> None:
"PreprocessUDF",
"InferenceUDF",
"TrainerUDF",
"PromTrainerUDF",
"DruidTrainerUDF",
"PostprocessUDF",
"UDFFactory",
"StreamConf",
Expand Down
25 changes: 17 additions & 8 deletions numalogic/udfs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@
# limitations under the License.

import logging
import warnings
from typing import ClassVar

from pynumaflow.mapper import Mapper, MultiProcMapper, AsyncMapper

from numalogic.udfs import NumalogicUDF
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.postprocess import PostprocessUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import TrainerUDF

_LOGGER = logging.getLogger(__name__)


class UDFFactory:
"""Factory class to fetch the right UDF."""

from numalogic.udfs import NumalogicUDF
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.postprocess import PostprocessUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import DruidTrainerUDF, PromTrainerUDF

_UDF_MAP: ClassVar[dict[str, type[NumalogicUDF]]] = {
"preprocess": PreprocessUDF,
"inference": InferenceUDF,
"postprocess": PostprocessUDF,
"trainer": TrainerUDF,
"druidtrainer": DruidTrainerUDF,
"promtrainer": PromTrainerUDF,
}

@classmethod
Expand All @@ -50,6 +51,12 @@ def get_udf_cls(cls, udf_name: str) -> type[NumalogicUDF]:
------
ValueError: If the UDF name is invalid
"""
if udf_name == "trainer":
warnings.warn(
"UDF name 'trainer' is deprecated. Use 'druidtrainer' or 'promtrainer' instead."
)
udf_name = "druidtrainer"

try:
return cls._UDF_MAP[udf_name]
except KeyError as err:
Expand Down Expand Up @@ -81,6 +88,8 @@ def get_udf_instance(cls, udf_name: str, **kwargs) -> NumalogicUDF:
class ServerFactory:
"""Factory class to fetch the right pynumaflow function server/mapper."""

from pynumaflow.mapper import Mapper, MultiProcMapper, AsyncMapper

_SERVER_MAP: ClassVar[dict] = {
"sync": Mapper,
"async": AsyncMapper,
Expand Down
5 changes: 5 additions & 0 deletions numalogic/udfs/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from numalogic.udfs.trainer._base import TrainerUDF
from numalogic.udfs.trainer._prom import PromTrainerUDF
from numalogic.udfs.trainer._druid import DruidTrainerUDF

__all__ = ["TrainerUDF", "PromTrainerUDF", "DruidTrainerUDF"]
79 changes: 2 additions & 77 deletions numalogic/udfs/trainer.py → numalogic/udfs/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from numalogic.base import StatelessTransformer
from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory, RegistryFactory
from numalogic.config._config import NumalogicConf
from numalogic.config.factory import ConnectorFactory
from numalogic.connectors import DruidFetcherConf
from numalogic.models.autoencoder import TimeseriesTrainer
from numalogic.tools.data import StreamingDataset
from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError
Expand Down Expand Up @@ -55,52 +53,12 @@ def __init__(
jitter_sec=jitter_sec,
jitter_steps_sec=jitter_steps_sec,
)
self.druid_conf = self.pl_conf.druid_conf

data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher")
try:
self.data_fetcher = data_fetcher_cls(
url=self.druid_conf.url, endpoint=self.druid_conf.endpoint
)
except AttributeError:
_LOGGER.warning("Druid config not found, data fetcher will not be initialized!")
self.data_fetcher = None

self._model_factory = ModelFactory()
self._preproc_factory = PreprocessFactory()
self._thresh_factory = ThresholdFactory()
self.train_msg_deduplicator = TrainMsgDeduplicator(r_client)

def register_druid_fetcher_conf(self, config_id: str, conf: DruidFetcherConf) -> None:
"""
Register DruidFetcherConf with the UDF.
Args:
config_id: Config ID
conf: DruidFetcherConf object
"""
self.pl_conf.druid_conf.id_fetcher[config_id] = conf

def get_druid_fetcher_conf(self, config_id: str) -> DruidFetcherConf:
"""
Get DruidFetcherConf with the given ID.
Args:
config_id: Config ID
Returns
-------
DruidFetcherConf object
Raises
------
ConfigNotFoundError: If config with the given ID is not found
"""
try:
return self.pl_conf.druid_conf.id_fetcher[config_id]
except KeyError as err:
raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err

@classmethod
def compute(
cls,
Expand Down Expand Up @@ -312,7 +270,7 @@ def get_feature_arr(

def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
    """
    Fetch data from a data connector.

    This is an abstract hook: concrete trainer UDFs (e.g. the Druid or
    Prometheus trainers) must override it with a source-specific
    implementation.

    Args:
        payload: TrainerPayload object

    Returns
    -------
    Dataframe

    Raises
    ------
    NotImplementedError: If the subclass has not overridden this method
    """
    raise NotImplementedError("fetch_data method not implemented")
118 changes: 118 additions & 0 deletions numalogic/udfs/trainer/_druid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import logging
import time
from typing import Optional

import pandas as pd

from numalogic.config.factory import ConnectorFactory
from numalogic.connectors import DruidFetcherConf
from numalogic.tools.exceptions import ConfigNotFoundError
from numalogic.tools.types import redis_client_t
from numalogic.udfs._config import PipelineConf
from numalogic.udfs.entities import TrainerPayload
from numalogic.udfs.trainer._base import TrainerUDF

_LOGGER = logging.getLogger(__name__)


class DruidTrainerUDF(TrainerUDF):
    """
    Trainer UDF that sources its training data from Druid.

    Args:
        r_client: Redis client
        pl_conf: Pipeline config
    """

    def __init__(
        self,
        r_client: redis_client_t,
        pl_conf: Optional[PipelineConf] = None,
    ):
        super().__init__(r_client=r_client, pl_conf=pl_conf)
        self.dataconn_conf = self.pl_conf.druid_conf
        fetcher_cls = ConnectorFactory.get_cls("DruidFetcher")
        # A missing druid section in the pipeline conf surfaces as an
        # AttributeError here; re-raise it as a config error.
        try:
            self.data_fetcher = fetcher_cls(
                url=self.dataconn_conf.url, endpoint=self.dataconn_conf.endpoint
            )
        except AttributeError as err:
            raise ConfigNotFoundError("Druid config not found!") from err

    def register_druid_fetcher_conf(self, config_id: str, conf: DruidFetcherConf) -> None:
        """
        Register a DruidFetcherConf under the given config ID.

        Args:
            config_id: Config ID
            conf: DruidFetcherConf object
        """
        self.pl_conf.druid_conf.id_fetcher[config_id] = conf

    def get_druid_fetcher_conf(self, config_id: str) -> DruidFetcherConf:
        """
        Look up the DruidFetcherConf registered under the given ID.

        Args:
            config_id: Config ID

        Returns
        -------
        DruidFetcherConf object

        Raises
        ------
        ConfigNotFoundError: If config with the given ID is not found
        """
        fetcher_registry = self.pl_conf.druid_conf.id_fetcher
        try:
            return fetcher_registry[config_id]
        except KeyError as err:
            raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err

    def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
        """
        Fetch training data from druid.

        Args:
            payload: TrainerPayload object

        Returns
        -------
        Dataframe (empty on fetch failure)
        """
        t_start = time.perf_counter()
        stream_conf = self.get_conf(payload.config_id)

        # Prefer the global fetcher conf; fall back to the per-config-id one.
        fetcher_conf = self.dataconn_conf.fetcher
        if not fetcher_conf and self.dataconn_conf.id_fetcher:
            fetcher_conf = self.get_druid_fetcher_conf(payload.config_id)
        if not fetcher_conf:
            raise ConfigNotFoundError(
                f"Druid fetcher config not found for config_id: {payload.config_id}!"
            )

        # Fetch failures are logged and reported as an empty frame rather
        # than propagated, so the caller can decide how to proceed.
        try:
            df = self.data_fetcher.fetch(
                datasource=fetcher_conf.datasource,
                filter_keys=stream_conf.composite_keys,
                filter_values=payload.composite_keys,
                dimensions=list(fetcher_conf.dimensions),
                delay=self.dataconn_conf.delay_hrs,
                granularity=fetcher_conf.granularity,
                aggregations=dict(fetcher_conf.aggregations),
                group_by=list(fetcher_conf.group_by),
                pivot=fetcher_conf.pivot,
                hours=stream_conf.numalogic_conf.trainer.train_hours,
            )
        except Exception:
            _LOGGER.exception("%s - Error while fetching data from druid", payload.uuid)
            return pd.DataFrame()

        _LOGGER.debug(
            "%s - Time taken to fetch data: %.3f sec, df shape: %s",
            payload.uuid,
            time.perf_counter() - t_start,
            df.shape,
        )
        return df
Loading

0 comments on commit 8369510

Please sign in to comment.