Skip to content

Commit

Permalink
add: druidfetcher support for different configId (#307)
Browse files Browse the repository at this point in the history
Update the config to support multiple DruidFetcher configurations, keyed by config ID.

---------

Signed-off-by: s0nicboOm <[email protected]>
  • Loading branch information
s0nicboOm committed Oct 4, 2023
1 parent f25f49a commit 64c2e95
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 47 deletions.
9 changes: 6 additions & 3 deletions numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from enum import IntEnum

from omegaconf import MISSING
from typing import Optional


class ConnectorType(IntEnum):
Expand Down Expand Up @@ -60,8 +59,12 @@ class DruidConf(ConnectorConf):
endpoint: Druid endpoint
delay_hrs: Delay in hours for fetching data from Druid
fetcher: DruidFetcherConf
id_fetcher: dict of DruidFetcherConf for fetching ids
Note: Either one of the fetcher or id_fetcher should be provided.
"""

endpoint: str
delay_hrs: float = 3.0
fetcher: DruidFetcherConf = MISSING
fetcher: Optional[DruidFetcherConf] = None
id_fetcher: Optional[dict[str, DruidFetcherConf]] = None
4 changes: 2 additions & 2 deletions numalogic/tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from numalogic.base import TorchModel, BaseThresholdModel, BaseTransformer

try:
from redis.client import AbstractRedis
from redis.client import Redis
except ImportError:
redis_client_t = TypeVar("redis_client_t")
else:
redis_client_t = TypeVar("redis_client_t", bound=AbstractRedis, covariant=True)
redis_client_t = TypeVar("redis_client_t", bound=Redis, covariant=True)

artifact_t = TypeVar(
"artifact_t",
Expand Down
50 changes: 44 additions & 6 deletions numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numalogic.config import PreprocessFactory, ModelFactory, ThresholdFactory, RegistryFactory
from numalogic.config._config import NumalogicConf
from numalogic.config.factory import ConnectorFactory
from numalogic.connectors import DruidFetcherConf
from numalogic.models.autoencoder import AutoencoderTrainer
from numalogic.tools.data import StreamingDataset
from numalogic.tools.exceptions import ConfigNotFoundError, RedisRegistryError
Expand Down Expand Up @@ -101,6 +102,36 @@ def get_conf(self, config_id: str) -> StreamConf:
except KeyError as err:
raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err

def register_druid_fetcher_conf(self, config_id: str, conf: DruidFetcherConf) -> None:
    """
    Register DruidFetcherConf with the UDF.

    The fetcher config is stored in ``self.pl_conf.druid_conf.id_fetcher``
    under ``config_id``; an existing entry with the same ID is overwritten.

    Args:
        config_id: Config ID
        conf: DruidFetcherConf object
    """
    druid_conf = self.pl_conf.druid_conf
    # id_fetcher may be None (e.g. when only the single `fetcher` is
    # configured), so create the mapping lazily instead of failing with a
    # TypeError on the subscript assignment.
    if druid_conf.id_fetcher is None:
        druid_conf.id_fetcher = {}
    druid_conf.id_fetcher[config_id] = conf

def get_druid_fetcher_conf(self, config_id: str) -> DruidFetcherConf:
    """
    Get DruidFetcherConf with the given ID.

    Looks up ``config_id`` in ``self.pl_conf.druid_conf.id_fetcher``.

    Args:
        config_id: Config ID

    Returns
    -------
        DruidFetcherConf object

    Raises
    ------
        ConfigNotFoundError: If config with the given ID is not found,
            or if no per-ID fetcher configs are registered at all
    """
    id_fetchers = self.pl_conf.druid_conf.id_fetcher
    # id_fetcher may be None; subscripting None would raise a TypeError
    # that escapes the KeyError handler below, so report it as "not found".
    if id_fetchers is None:
        raise ConfigNotFoundError(f"Config with ID {config_id} not found!")
    try:
        return id_fetchers[config_id]
    except KeyError as err:
        raise ConfigNotFoundError(f"Config with ID {config_id} not found!") from err

@classmethod
def compute(
cls,
Expand Down Expand Up @@ -316,18 +347,25 @@ def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
"""
_start_time = time.perf_counter()
_conf = self.get_conf(payload.config_id)
_fetcher_conf = self.druid_conf.fetcher or (
self.get_druid_fetcher_conf(payload.config_id) if self.druid_conf.id_fetcher else None
)
if not _fetcher_conf:
raise ConfigNotFoundError(
f"Druid fetcher config not found for config_id: {payload.config_id}!"
)

try:
_df = self.data_fetcher.fetch(
datasource=self.druid_conf.fetcher.datasource,
datasource=_fetcher_conf.datasource,
filter_keys=_conf.composite_keys,
filter_values=payload.composite_keys,
dimensions=list(self.druid_conf.fetcher.dimensions),
dimensions=list(_fetcher_conf.dimensions),
delay=self.druid_conf.delay_hrs,
granularity=self.druid_conf.fetcher.granularity,
aggregations=dict(self.druid_conf.fetcher.aggregations),
group_by=list(self.druid_conf.fetcher.group_by),
pivot=self.druid_conf.fetcher.pivot,
granularity=_fetcher_conf.granularity,
aggregations=dict(_fetcher_conf.aggregations),
group_by=list(_fetcher_conf.group_by),
pivot=_fetcher_conf.pivot,
hours=_conf.numalogic_conf.trainer.train_hours,
)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.0a10"
version = "0.6.0a11"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
13 changes: 7 additions & 6 deletions tests/udfs/resources/_config2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ redis_conf:
druid_conf:
url: "druid-endpoint"
endpoint: "endpoint"
fetcher:
dimensions: [ "col1" ]
datasource: "table-name"
group_by: [ "timestamp", "col1" ]
pivot:
columns: [ "col2" ]
id_fetcher:
druid-config:
dimensions: [ "col1" ]
datasource: "table-name"
group_by: [ "timestamp", "col1" ]
pivot:
columns: [ "col2" ]
Loading

0 comments on commit 64c2e95

Please sign in to comment.