Skip to content

Commit

Permalink
feat: support both base conf and app conf (#328)
Browse files (browse the repository at this point in the history)
- Add support for base config and app config merge
- Throw an error if no conf exists
- Paths can be altered using env vars

Signed-off-by: Avik Basu <[email protected]>
Co-authored-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 and ab93 committed Nov 30, 2023
1 parent 8b7f45f commit dfc383a
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 17 deletions.
3 changes: 3 additions & 0 deletions numalogic/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@
BASE_DIR = os.path.split(NUMALOGIC_DIR)[0]
TESTS_DIR = os.path.join(NUMALOGIC_DIR, "../tests")
BASE_CONF_DIR = os.path.join(BASE_DIR, "config")

DEFAULT_BASE_CONF_PATH = os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml")
DEFAULT_APP_CONF_PATH = os.path.join(BASE_CONF_DIR, "app-configs", "config.yaml")
5 changes: 4 additions & 1 deletion numalogic/connectors/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from numalogic.connectors._config import RedisConf
from numalogic.tools.exceptions import EnvVarNotFoundError
from numalogic.tools.exceptions import EnvVarNotFoundError, ConfigNotFoundError
from numalogic.tools.types import redis_client_t
from redis.backoff import ExponentialBackoff
from redis.exceptions import RedisClusterException, RedisError
Expand Down Expand Up @@ -86,6 +86,9 @@ def get_redis_client_from_conf(redis_conf: RedisConf, **kwargs) -> redis_client_
-------
Redis client instance
"""
if not redis_conf:
raise ConfigNotFoundError("RedisConf not found!")

auth = os.getenv("REDIS_AUTH")
if not auth:
raise EnvVarNotFoundError("REDIS_AUTH not set!")
Expand Down
16 changes: 8 additions & 8 deletions numalogic/udfs/__main__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import logging
import os
import sys
from typing import Final

from numalogic._constants import BASE_CONF_DIR
from numalogic._constants import DEFAULT_BASE_CONF_PATH, DEFAULT_APP_CONF_PATH
from numalogic.connectors.redis import get_redis_client_from_conf
from numalogic.udfs import load_pipeline_conf, UDFFactory, ServerFactory, set_logger

LOGGER = logging.getLogger(__name__)

# TODO support user config paths
CONF_FILE_PATH = os.getenv(
"CONF_PATH", default=os.path.join(BASE_CONF_DIR, "default-configs", "config.yaml")
)
BASE_CONF_FILE_PATH: Final[str] = os.getenv("BASE_CONF_PATH", default=DEFAULT_BASE_CONF_PATH)
APP_CONF_FILE_PATH: Final[str] = os.getenv("APP_CONF_PATH", default=DEFAULT_APP_CONF_PATH)


def init_server(step: str, server_type: str):
"""Initializes and returns the server."""
pipeline_conf = load_pipeline_conf(CONF_FILE_PATH)
LOGGER.info("Merging config with file paths: %s, %s", BASE_CONF_FILE_PATH, APP_CONF_FILE_PATH)
pipeline_conf = load_pipeline_conf(BASE_CONF_FILE_PATH, APP_CONF_FILE_PATH)
LOGGER.info("Pipeline config: %s", pipeline_conf)

redis_client = get_redis_client_from_conf(pipeline_conf.redis_conf)
Expand All @@ -25,7 +25,7 @@ def init_server(step: str, server_type: str):
return ServerFactory.get_server_instance(server_type, handler=udf)


def start_server():
def start_server() -> None:
"""Starts the pynumaflow server."""
set_logger()
step = sys.argv[1]
Expand All @@ -35,7 +35,7 @@ def start_server():
except (IndexError, TypeError):
server_type = "sync"

LOGGER.info("Running %s on %s server with config path %s", step, server_type, CONF_FILE_PATH)
LOGGER.info("Running %s on %s server", step, server_type)

server = init_server(step, server_type)
server.start()
Expand Down
22 changes: 19 additions & 3 deletions numalogic/udfs/_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Optional

Expand All @@ -11,6 +12,9 @@
PrometheusConf,
DruidConf,
)
from numalogic.tools.exceptions import ConfigNotFoundError

_logger = logging.getLogger(__name__)


@dataclass
Expand All @@ -34,8 +38,20 @@ class PipelineConf:
druid_conf: Optional[DruidConf] = None


def load_pipeline_conf(path: str) -> PipelineConf:
conf = OmegaConf.load(path)
def load_pipeline_conf(*paths: str) -> PipelineConf:
confs = []
for _path in paths:
try:
conf = OmegaConf.load(_path)
except FileNotFoundError:
_logger.warning("Config file path: %s not found. Skipping...", _path)
continue
confs.append(conf)

if not confs:
_err_msg = f"None of the given conf paths exist: {paths}"
raise ConfigNotFoundError(_err_msg)

schema = OmegaConf.structured(PipelineConf)
conf = OmegaConf.merge(schema, conf)
conf = OmegaConf.merge(schema, *confs)
return OmegaConf.to_object(conf)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.1.dev4"
version = "0.6.1.dev5"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
2 changes: 1 addition & 1 deletion tests/udfs/resources/_config3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ stream_confs:
accelerator: cpu
max_epochs: 5
default:
config_id: "odl-graphql"
config_id: "default"
source: "prometheus"
composite_keys: [ "namespace", "app" ]
metrics: ["namespace_app_rollouts_http_request_error_rate"]
Expand Down
29 changes: 29 additions & 0 deletions tests/udfs/resources/_config4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
stream_confs:
mycustomconf:
config_id: "mycustomconf"
source: "prometheus"
composite_keys: [ "namespace", "app" ]
metrics: [ "namespace_app_rollouts_cpu_utilization", "namespace_app_rollouts_http_request_error_rate", "namespace_app_rollouts_memory_utilization" ]
window_size: 12
numalogic_conf:
model:
name: "Conv1dVAE"
conf:
seq_len: 12
n_features: 3
latent_dim: 1
preprocess:
- name: "StandardScaler"
threshold:
name: "MahalanobisThreshold"
trainer:
train_hours: 3
min_train_size: 100
pltrainer_conf:
accelerator: cpu
max_epochs: 5
redis_conf:
url: "http://localhost:6222"
port: 26379
expiry: 360
master_name: "mymaster"
36 changes: 33 additions & 3 deletions tests/udfs/test_main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,59 @@
import unittest
from unittest.mock import patch

from omegaconf import OmegaConf
from pynumaflow.mapper import Mapper, MultiProcMapper

from numalogic._constants import TESTS_DIR
from numalogic.tools.exceptions import ConfigNotFoundError

CONFIG_PATH = f"{TESTS_DIR}/udfs/resources/_config.yaml"
BASE_CONFIG_PATH = f"{TESTS_DIR}/udfs/resources/_config3.yaml"
APP_CONFIG_PATH = f"{TESTS_DIR}/udfs/resources/_config4.yaml"
REDIS_AUTH = "123"


class TestMainScript(unittest.TestCase):
@patch.dict("os.environ", {"CONF_PATH": CONFIG_PATH, "REDIS_AUTH": REDIS_AUTH})
@patch.dict("os.environ", {"BASE_CONF_PATH": BASE_CONFIG_PATH, "REDIS_AUTH": REDIS_AUTH})
def test_init_server_01(self):
from numalogic.udfs.__main__ import init_server

server = init_server("preprocess", "sync")
self.assertIsInstance(server, Mapper)

@patch.dict("os.environ", {"CONF_PATH": CONFIG_PATH, "REDIS_AUTH": REDIS_AUTH})
@patch.dict("os.environ", {"BASE_CONF_PATH": BASE_CONFIG_PATH, "REDIS_AUTH": REDIS_AUTH})
def test_init_server_02(self):
from numalogic.udfs.__main__ import init_server

server = init_server("inference", "multiproc")
self.assertIsInstance(server, MultiProcMapper)

def test_conf_loader(self):
from numalogic.udfs import load_pipeline_conf

plconf = load_pipeline_conf(BASE_CONFIG_PATH, APP_CONFIG_PATH)
base_conf = OmegaConf.load(BASE_CONFIG_PATH)
app_conf = OmegaConf.load(APP_CONFIG_PATH)

self.assertListEqual(
list(plconf.stream_confs),
list(base_conf["stream_confs"]) + list(app_conf["stream_confs"]),
)

def test_conf_loader_appconf_not_exist(self):
from numalogic.udfs import load_pipeline_conf

app_conf_path = "_random.yaml"
plconf = load_pipeline_conf(BASE_CONFIG_PATH, app_conf_path)
base_conf = OmegaConf.load(BASE_CONFIG_PATH)

self.assertListEqual(list(plconf.stream_confs), list(base_conf["stream_confs"]))

def test_conf_loader_err(self):
from numalogic.udfs import load_pipeline_conf

with self.assertRaises(ConfigNotFoundError):
load_pipeline_conf("_random1.yaml", "_random2.yaml")


if __name__ == "__main__":
unittest.main()

0 comments on commit dfc383a

Please sign in to comment.