Skip to content

Commit

Permalink
Adding RDS Trainer UDF changes (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
SaiSharathReddy committed May 7, 2024
1 parent e7e879c commit f29f771
Show file tree
Hide file tree
Showing 29 changed files with 15,703 additions and 54 deletions.
2 changes: 1 addition & 1 deletion numalogic/config/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_cls(cls, name: str):
class ConnectorFactory(_ObjectFactory):
"""Factory class for data connectors."""

_CLS_SET: ClassVar[frozenset] = frozenset({"PrometheusFetcher", "DruidFetcher"})
_CLS_SET: ClassVar[frozenset] = frozenset({"PrometheusFetcher", "DruidFetcher", "RDSFetcher"})

@classmethod
def get_cls(cls, name: str):
Expand Down
7 changes: 6 additions & 1 deletion numalogic/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
DruidConf,
DruidFetcherConf,
ConnectorType,
RDSConf,
RDSFetcherConf,
)
from numalogic.connectors.rds import RDSFetcher
from numalogic.connectors.prometheus import PrometheusFetcher

__all__ = [
Expand All @@ -18,9 +21,11 @@
"DruidFetcherConf",
"ConnectorType",
"PrometheusFetcher",
"RDSFetcher",
"RDSConf",
"RDSFetcherConf",
]


if find_spec("pydruid"):
from numalogic.connectors.druid import DruidFetcher # noqa: F401

Expand Down
67 changes: 67 additions & 0 deletions numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
from numalogic.connectors.utils.aws.config import RDSConnectionConfig
from numalogic.connectors.exceptions import RDSFetcherConfValidationException


class ConnectorType(IntEnum):
Expand Down Expand Up @@ -52,6 +54,49 @@ def __post_init__(self):
self.aggregations = {"count": doublesum("count")}


@dataclass
class RDSFetcherConf:
"""
RDSFetcherConf class represents the configuration for fetching data from an RDS data source.
Args:
datasource (str): The name of the data source.
dimensions (list[str]): A list of dimension column names.
group_by (list[str]): A list of column names to group the data by.
pivot (Pivot): An instance of the Pivot class representing the pivot configuration.
hash_query_type (bool): A boolean indicating whether to use hash query type.
hash_column_name (Optional[str]): The name of the hash column. (default: None)
datetime_column_name (str): The name of the datetime column. (default: "eventdatetime")
metrics (list[str]): A list of metric column names.
Methods
-------
__post_init__(): Performs post-initialization validation checks.
Raises
------
RDSFetcherConfValidationException: If the hash_query_type is enabled
but hash_column_name is not provided.
"""

datasource: str
dimensions: list[str]
# metric column names
metrics: list[str]
group_by: list[str] = field(default_factory=list)
pivot: Pivot = field(default_factory=lambda: Pivot())
hash_query_type: bool = True
hash_column_name: str = "model_md5_hash"
datetime_column_name: str = "eventdatetime"

def __post_init__(self):
if self.hash_query_type:
if self.hash_column_name.strip() == "":
raise RDSFetcherConfValidationException(
"when hash_query_type is enabled, hash_column_name is required property "
)


@dataclass
class DruidConf(ConnectorConf):
"""
Expand All @@ -70,3 +115,25 @@ class DruidConf(ConnectorConf):
delay_hrs: float = 3.0
fetcher: Optional[DruidFetcherConf] = None
id_fetcher: Optional[dict[str, DruidFetcherConf]] = None


@dataclass
class RDSConf:
"""
Class representing the configuration for fetching data from an RDS data source.
Args:
connection_conf (RDSConnectionConfig): An instance of the RDSConnectionConfig class
representing the connection configuration.
delay_hrs (float): The delay in hours for fetching data. Defaults to 3.0.
fetcher (Optional[RDSFetcherConf]): An optional instance of the RDSFetcherConf class
representing the fetcher configuration. Defaults to None.
id_fetcher (Optional[dict[str, RDSFetcherConf]]): An optional dictionary mapping IDs to
instances of the RDSFetcherConf class representing the fetcher configuration.
Defaults to None.
"""

connection_conf: RDSConnectionConfig
delay_hrs: float = 3.0
fetcher: Optional[RDSFetcherConf] = None
id_fetcher: Optional[dict[str, RDSFetcherConf]] = None
10 changes: 10 additions & 0 deletions numalogic/connectors/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class ConnectorFetcherException(Exception):
"""Custom exception class for grouping all Connector Exceptions together."""

pass


class RDSFetcherConfValidationException(ConnectorFetcherException):
"""A custom exception class for handling validation errors in RDSFetcherConf."""

pass
14 changes: 7 additions & 7 deletions numalogic/connectors/rds/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional
import pandas as pd
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConnectionConfig
from numalogic.connectors.utils.aws.boto3_client_manager import Boto3ClientManager
import logging
from numalogic.connectors._config import Pivot
Expand All @@ -13,7 +13,7 @@
def format_dataframe(
df: pd.DataFrame,
query: str,
datetime_field_name: str,
datetime_column_name: str,
group_by: Optional[list[str]] = None,
pivot: Optional[Pivot] = None,
) -> pd.DataFrame:
Expand All @@ -26,7 +26,7 @@ def format_dataframe(
The input DataFrame to be formatted.
query : str
The SQL query used to retrieve the data.
datetime_field_name : str
datetime_column_name : str
The name of the datetime field in the DataFrame.
group_by : Optional[list[str]], optional
A list of column names to group the DataFrame by, by default None.
Expand All @@ -40,8 +40,8 @@ def format_dataframe(
"""
_start_time = time.perf_counter()
df["timestamp"] = pd.to_datetime(df[datetime_field_name]).astype("int64") // 10**6
df.drop(columns=datetime_field_name, inplace=True)
df["timestamp"] = pd.to_datetime(df[datetime_column_name]).astype("int64") // 10**6
df.drop(columns=datetime_column_name, inplace=True)
if group_by:
df = df.groupby(by=group_by).sum().reset_index()

Expand All @@ -65,12 +65,12 @@ class represents a data fetcher for RDS (Relational Database Service) connection
connection, and executing queries.
Args:
- db_config (RDSConfig): The configuration object for the RDS connection.
- db_config (RDSConnectionConfig): The configuration object for the RDS connection.
- kwargs (dict): Additional keyword arguments.
"""

def __init__(self, db_config: RDSConfig, **kwargs):
def __init__(self, db_config: RDSConnectionConfig, **kwargs):
self.kwargs = kwargs
self.db_config = db_config
self.connection = None
Expand Down
16 changes: 10 additions & 6 deletions numalogic/connectors/rds/_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from numalogic.connectors._base import DataFetcher
from numalogic.connectors._config import Pivot
from numalogic.connectors.rds._base import format_dataframe
from numalogic.connectors.utils.aws.config import RDSConfig
from numalogic.connectors.utils.aws.config import RDSConnectionConfig
import logging
import pandas as pd
from numalogic.connectors.rds.db.factory import RdsFactory
Expand All @@ -19,12 +19,12 @@ class is a subclass of DataFetcher and ABC (Abstract Base Class).
Attributes
----------
db_config (RDSConfig): The configuration object for the RDS instance.
db_config (RDSConnectionConfig): The configuration object for the RDS instance.
fetcher (db.CLASS_TYPE): The fetcher object for the specific database type.
"""

def __init__(self, db_config: RDSConfig):
def __init__(self, db_config: RDSConnectionConfig):
super().__init__(db_config.endpoint)
self.db_config = db_config
factory_object = RdsFactory()
Expand All @@ -34,7 +34,7 @@ def __init__(self, db_config: RDSConfig):
def fetch(
self,
query,
datetime_field_name: str,
datetime_column_name: str,
pivot: Optional[Pivot] = None,
group_by: Optional[list[str]] = None,
) -> pd.DataFrame:
Expand All @@ -43,7 +43,7 @@ def fetch(
Args:
query (str): The SQL query to be executed.
datetime_field_name (str): The name of the datetime field in the fetched data.
datetime_column_name (str): The name of the datetime field in the fetched data.
pivot (Optional[Pivot], optional): The pivot configuration for the fetched data.
Defaults to None.
group_by (Optional[list[str]], optional): The list of fields to group the
Expand All @@ -60,7 +60,11 @@ def fetch(
return pd.DataFrame()

formatted_df = format_dataframe(
df, query=query, datetime_field_name=datetime_field_name, pivot=pivot, group_by=group_by
df,
query=query,
datetime_column_name=datetime_column_name,
pivot=pivot,
group_by=group_by,
)
_end_time = time.perf_counter() - _start_time
_LOGGER.info("RDS Query: %s Fetch Time: %.4fs", query, _end_time)
Expand Down
12 changes: 6 additions & 6 deletions numalogic/connectors/rds/db/mysql_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import logging

from numalogic.connectors.utils.aws.config import DatabaseTypes, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseTypes, RDSConnectionConfig

_LOGGER = logging.getLogger(__name__)

Expand All @@ -16,8 +16,8 @@ class MysqlFetcher(RDSBase):
"""
class that inherits from RDSBase. It is used to fetch data from a MySQL database.
- __init__(self, db_config: RDSConfig, **kwargs): Initializes the MysqlFetcher object with
the given RDSConfig and additional keyword arguments.
- __init__(self, db_config: RDSConnectionConfig, **kwargs): Initializes the MysqlFetcher object
with the given RDSConnectionConfig and additional keyword arguments.
The MysqlFetcher class is designed to be used as a base class for fetching data from a MySQL
database. It provides methods for establishing a connection, executing queries,
Expand All @@ -27,7 +27,7 @@ class that inherits from RDSBase. It is used to fetch data from a MySQL database

database_type = DatabaseTypes.MYSQL

def __init__(self, db_config: RDSConfig, **kwargs):
def __init__(self, db_config: RDSConnectionConfig, **kwargs):
super().__init__(db_config)
self.db_config = db_config
self.kwargs = kwargs
Expand All @@ -44,8 +44,8 @@ def get_connection(self) -> pymysql.Connection:
------
None
Notes: - If SSL/TLS is enabled and configured in the RDSConfig object, the connection
will be established with SSL/TLS. - If SSL/TLS is not enabled or configured,
Notes: - If SSL/TLS is enabled and configured in the RDSConnectionConfig object,
the connection will be established with SSL/TLS. - If SSL/TLS is not enabled or configured,
the connection will be established without SSL/TLS.
"""
Expand Down
4 changes: 2 additions & 2 deletions numalogic/connectors/utils/aws/boto3_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from boto3 import Session
import logging

from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConfig
from numalogic.connectors.utils.aws.config import DatabaseServiceProvider, RDSConnectionConfig
from numalogic.connectors.utils.aws.exceptions import UnRecognizedAWSClientException
from numalogic.connectors.utils.aws.sts_client_manager import STSClientManager

Expand Down Expand Up @@ -30,7 +30,7 @@ class Boto3ClientManager:
methods.
"""

def __init__(self, configurations: RDSConfig):
def __init__(self, configurations: RDSConnectionConfig):
self.rds_client = None
self.athena_client = None
self.configurations = configurations
Expand Down
2 changes: 1 addition & 1 deletion numalogic/connectors/utils/aws/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class RDBMSConfig:


@dataclass
class RDSConfig(AWSConfig, RDBMSConfig):
class RDSConnectionConfig(AWSConfig, RDBMSConfig):
"""
Class representing the configuration for an RDS (Relational Database Service) instance.
Expand Down
6 changes: 3 additions & 3 deletions numalogic/connectors/utils/aws/db_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from numalogic.tools.exceptions import ConfigNotFoundError
from omegaconf import OmegaConf
from numalogic.connectors.utils.aws.config import RDSConfig
from numalogic.connectors.utils.aws.config import RDSConnectionConfig

_LOGGER = logging.getLogger(__name__)


def load_db_conf(*paths: str) -> RDSConfig:
def load_db_conf(*paths: str) -> RDSConnectionConfig:
"""
Load database configuration from one or more YAML files.
Expand Down Expand Up @@ -38,6 +38,6 @@ def load_db_conf(*paths: str) -> RDSConfig:
_err_msg = f"None of the given conf paths exist: {paths}"
raise ConfigNotFoundError(_err_msg)

schema = OmegaConf.structured(RDSConfig)
schema = OmegaConf.structured(RDSConnectionConfig)
conf = OmegaConf.merge(schema, *confs)
return OmegaConf.to_object(conf)
8 changes: 7 additions & 1 deletion numalogic/tools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class DataFormatError(Exception):


class DruidFetcherError(Exception):
"""Base class for all exceptions raised by the PrometheusFetcher class."""
"""Base class for all exceptions raised by the DruidFetcher class."""

pass


class RDSFetcherError(Exception):
"""Base class for all exceptions raised by the RDSFetcher class."""

pass
4 changes: 2 additions & 2 deletions numalogic/udfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from logging import config as logconf
import os


from numalogic._constants import BASE_DIR
from numalogic.udfs._base import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf, MLPipelineConf, load_pipeline_conf
Expand All @@ -11,7 +10,7 @@
from numalogic.udfs.inference import InferenceUDF
from numalogic.udfs.postprocess import PostprocessUDF
from numalogic.udfs.preprocess import PreprocessUDF
from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF
from numalogic.udfs.trainer import TrainerUDF, PromTrainerUDF, DruidTrainerUDF, RDSTrainerUDF


def set_logger() -> None:
Expand All @@ -32,6 +31,7 @@ def set_logger() -> None:
"TrainerUDF",
"PromTrainerUDF",
"DruidTrainerUDF",
"RDSTrainerUDF",
"PostprocessUDF",
"UDFFactory",
"StreamConf",
Expand Down
3 changes: 3 additions & 0 deletions numalogic/udfs/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RedisConf,
PrometheusConf,
DruidConf,
RDSConf,
)
from numalogic.tools.exceptions import ConfigNotFoundError

Expand Down Expand Up @@ -68,6 +69,7 @@ class PipelineConf:
registry_conf (Optional[RegistryInfo]): The configuration for the registry.
prometheus_conf (Optional[PrometheusConf]): The configuration for Prometheus.
druid_conf (Optional[DruidConf]): The configuration for Druid.
rds_conf (Optional[RDSConf]): The configuration for RDS.
"""

stream_confs: dict[str, StreamConf] = field(default_factory=dict)
Expand All @@ -77,6 +79,7 @@ class PipelineConf:
)
prometheus_conf: Optional[PrometheusConf] = None
druid_conf: Optional[DruidConf] = None
rds_conf: Optional[RDSConf] = None


def load_pipeline_conf(*paths: str) -> PipelineConf:
Expand Down
Loading

0 comments on commit f29f771

Please sign in to comment.